blob: 5138e3f8031af98c6dd9d5535f6a3fb3af87595f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
35
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
# (flatc emits plain classes whose attributes are ints; instantiating them here
# gives module-wide DType.X / Op.X / ResizeMode.X style access).
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
49
Kevin Cheng550ccc52021-03-03 11:21:43 -080050
class TosaQuantGen:
    """Random quantization-info helpers.

    Operators opt in by naming one of the qg* functions under 'qgen' in their
    operator definition; each builds a TosaSerializerQuantInfo populated with
    zero points randomized appropriately for the dtype.
    """

    def __init__(self):
        # Stateless: every generator is a static method.
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        """Return a random zero point for dtype (0 for non-quantized dtypes)."""
        zp_range = {DType.INT8: (-128, 128), DType.UINT8: (0, 256)}.get(dtype)
        if zp_range is None:
            return 0
        return testGen.randInt(zp_range[0], zp_range[1])

    @staticmethod
    def qgUnary(testGen, op, dtype):
        """Unary-op quant info: random input and output zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        zp_in = TosaQuantGen.getQinfo(testGen, dtype)
        zp_out = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.UnaryQuantInfo(zp_in, zp_out)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        """Convolution quant info: random input and weight zero points.

        Accepts either a single dtype (shared by input/weights/accumulator)
        or an [input, weights, accumulator] dtype list.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList  # [input, weights, accumulator]
        else:
            dtypeList = [dtype_or_dtypeList] * 3  # same dtype for all three
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        """Matmul quant info: random zero points for both inputs."""
        qinfo = ts.TosaSerializerQuantInfo()
        zp_a = TosaQuantGen.getQinfo(testGen, dtype)
        zp_b = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.MatMulQuantInfo(zp_a, zp_b)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        """Pad quant info: a single random input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Split scaleFp into a fixed-point (multiplier, shift) pair.

        Derived from the reference model's computeMultiplierAndShiftTosaScale32:
        the multiplier is a scaleBits-wide mantissa and shift the matching
        right-shift amount, so that scaleFp ~= multiplier * 2**(-shift).
        """
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # Keep the multiplier positive even for negative scales.
            mantissa = -mantissa

        one = 1 << scaleBits
        multiplier = round(mantissa * one)
        assert multiplier <= one

        if multiplier == one:
            # Rounding carried the mantissa up to exactly 1.0; renormalize.
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        assert multiplier <= one
        assert 0 <= shift <= 63

        return multiplier, shift
129
130
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        # Stateless: all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        """Give every operand (placeholder + const) one identical random shape."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        """Like tgBasic but rank-4 only, with the batch dimension capped."""
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        """Shapes for gather/scatter: values_in plus an input whose middle (W)
        dimension is independently randomized."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Batch and channel dims are shared with values_in; only W differs.
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        """All operands share one shape except one randomly chosen operand,
        which gets a randomly chosen dimension set to 1 (broadcastable)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank):
        """[ifm, filter, bias] shapes for CONV2D (NHWC input, OHWI filter)."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        """[ifm, filter, bias] shapes for TRANSPOSE_CONV2D (same layout as tgConv2D)."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        """[ifm, filter, bias] shapes for DEPTHWISE_CONV2D (HWCM filter)."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        """[input, filter, bias] shapes for FULLY_CONNECTED (rank 2)."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        # Output-channel count drawn directly from the configured shape range.
        filter_oc = (
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )[0]
        )
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        """[a, b] shapes for MATMUL: b shares a's batch and inner dims,
        with a random output-channel count."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank):
        """One shared shape for the declared operands plus 0-3 extra concat inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        # NOTE(review): the comment above mentions accounting for pl, but the
        # randInt bounds below do not use pl — confirm intent.
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis):
        """Shrink const concat inputs by repeatedly halving the first shape
        along axis, so the concatenated total stays the original extent.

        NOTE(review): this mutates shapeList[0] in place (shape aliases it);
        downstream code appears to rely on the returned list only — confirm.
        """
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        shape = shapeList[0]
        if len(shapeList) == 2 or shape[axis] < len(shapeList):
            return shapeList

        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList)-2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
371
372
Kevin Cheng550ccc52021-03-03 11:21:43 -0800373
Eric Kunzee5e26762020-10-13 16:11:07 -0700374class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800375 """Argument generators create exhaustive or random lists of attributes for operators that take
376 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
377 tuples where the descriptive_name is appended to the test name and the arglist is expanded
378 as arguments to the operator build function."""
379
    def __init__(self):
        # Stateless: all argument generators are static methods.
        pass
382
383 @staticmethod
384 def agNone(testGen, opName, shapeList, dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800385 """A trivial argument generator for operators that don't take any
386 non-tensor arguments"""
387 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700388
389 @staticmethod
390 def agAxis(testGen, opName, shapeList, dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800391 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700392 axes = []
393
394 shape = shapeList[0]
395
396 for a in range(0, len(shape)):
Matthew Haddon43e37192021-07-09 14:13:02 +0100397 axes.append(("axis{}".format(a), [a]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700398 return axes
399
400 @staticmethod
401 def agConv2D(testGen, opName, shapeList, dtype):
402 arg_list = []
403
404 ifm_shape = shapeList[0]
405 filter_shape = shapeList[1]
406
407 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800408 assert len(ifm_shape) == 4
409 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700410
411 maxStride = testGen.args.max_conv_stride
412 maxPadding = testGen.args.max_conv_padding + 1
413 maxDilation = testGen.args.max_conv_dilation
414
415 # Strides, padding, dilations
416 for stride in range(0, maxStride ** 2):
417 for padding in range(0, (maxPadding) ** 4):
418 for dilation in range(0, maxDilation ** 2):
419
Kevin Cheng550ccc52021-03-03 11:21:43 -0800420 s = [stride // maxStride + 1, stride % maxStride + 1]
421 p = [
422 (padding // (maxPadding * 4)) % maxPadding,
423 (padding // (maxPadding * 2)) % maxPadding,
424 (padding // (maxPadding * 1)) % maxPadding,
425 padding % maxPadding,
426 ]
427 d = [dilation // maxDilation + 1, dilation % maxDilation + 1]
Eric Kunzee5e26762020-10-13 16:11:07 -0700428
429 # 4 padding parameters for regular conv2d
Kevin Cheng550ccc52021-03-03 11:21:43 -0800430 arg_list.append(
431 (
432 "st{}{}_pad{}{}{}{}_dilat{}{}".format(
433 s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
434 ),
435 [s, p, d],
436 )
437 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700438 return arg_list
439
440 @staticmethod
441 def agTransposeConv2D(testGen, opName, shapeList, dtype):
442 arg_list = []
443
444 ifm_shape = shapeList[0]
445 filter_shape = shapeList[1]
446
447 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800448 assert len(ifm_shape) == 4
449 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700450
451 maxStride = testGen.args.max_conv_stride
452 maxPadding = testGen.args.max_conv_padding + 1
453 maxDilation = testGen.args.max_conv_dilation
454
455 # Strides, padding, dilations
456 for stride in range(0, maxStride ** 2):
457 for out_padding in range(0, (maxPadding) ** 2):
458 for dilation in range(0, maxDilation ** 2):
459
Kevin Cheng550ccc52021-03-03 11:21:43 -0800460 s = [stride // maxStride + 1, stride % maxStride + 1]
461 p = [
462 (out_padding // (maxPadding * 1)) % maxPadding,
463 out_padding % maxPadding,
464 ]
465 d = [dilation // maxDilation + 1, dilation % maxDilation + 1]
Eric Kunzee5e26762020-10-13 16:11:07 -0700466
Kevin Cheng550ccc52021-03-03 11:21:43 -0800467 oh = (
468 ifm_shape[1]
469 - filter_shape[1]
470 - (filter_shape[1] - 1) * (d[0] - 1)
471 + 2 * p[0]
472 ) // s[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700473
Kevin Cheng550ccc52021-03-03 11:21:43 -0800474 ow = (
475 ifm_shape[2]
476 - filter_shape[2]
477 - (filter_shape[2] - 1) * (d[1] - 1)
478 + 2 * p[1]
479 ) // s[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700480
481 # Output shape
Kevin Cheng550ccc52021-03-03 11:21:43 -0800482 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Eric Kunzee5e26762020-10-13 16:11:07 -0700483
Kevin Cheng550ccc52021-03-03 11:21:43 -0800484 arg_list.append(
485 (
486 "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
487 s[0],
488 s[1],
489 p[0],
490 p[1],
491 d[0],
492 d[1],
493 os[0],
494 os[1],
495 os[2],
496 os[3],
497 ),
498 [s, p, d, os],
499 )
500 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700501
502 return arg_list
503
504 @staticmethod
505 def agPad(testGen, opName, shapeList, dtype):
506 arg_list = []
507 rank = len(shapeList[0])
508
Les Bell7ffccce2021-07-28 15:37:02 +0100509 # Exhaustively test combinations of padding on each side of each dimension
510 # - the range of padding values is defined by pad_min and pad_max
511 # - for padding >9, the name format needs to be more distinctive
512 pad_min, pad_max = 0, 1
513 pad_values = [x for x in range(pad_min, pad_max + 1)]
514 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
515 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700516
Les Bell7ffccce2021-07-28 15:37:02 +0100517 for paddings in shape_pad_values:
518 name = "pad"
519 for r in range(rank):
520 before, after = paddings[r]
521 name = f"{name}{before}{after}"
522 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700523
524 return arg_list
525
526 @staticmethod
527 def agPooling(testGen, opName, shapeList, dtype):
528 arg_list = []
529
530 shape = shapeList[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800531 assert len(shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700532
533 maxStride = testGen.args.max_pooling_stride
534 maxKernel = testGen.args.max_pooling_kernel
535 maxPadding = testGen.args.max_pooling_padding + 1
536
537 for kernel in range(0, maxKernel ** 2):
538 for stride in range(0, maxStride ** 2):
539 for padding in range(0, maxPadding ** 4):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800540 s = [stride // maxStride + 1, stride % maxStride + 1]
541 k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
542 p = [
543 (padding // (maxPadding * 4)) % maxPadding,
544 (padding // (maxPadding * 2)) % maxPadding,
545 (padding // (maxPadding * 1)) % maxPadding,
546 padding % maxPadding,
547 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700548
Kevin Cheng550ccc52021-03-03 11:21:43 -0800549 arg_list.append(
550 (
551 "st{}{}_kern{}{}_pad{}{}{}{}".format(
552 s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
553 ),
554 [k, s, p],
555 )
556 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700557 return arg_list
558
559 @staticmethod
560 def agCast(testGen, opName, shapeList, inDtype):
561 arg_list = []
562
563 # Enumerate the output types here
564 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800565 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700566 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800567 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700568 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800569 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700570 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800571 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700572 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800573 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700574 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800575 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700576
577 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800578 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700579
580 return arg_list
581
582 @staticmethod
583 def agRescale(testGen, opName, shapeList, inDtype):
584 arg_list = []
585
586 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100587 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
588 if inDtype == DType.UINT8 and dtype != DType.INT8:
589 # The only output dtype for UINT8 is INT8, skip all other combinations
590 continue
591 if inDtype != DType.INT8 and dtype == DType.UINT8:
592 # The only input dtype for UINT8 is INT8, skip all other combinations
593 continue
594
Kevin Cheng550ccc52021-03-03 11:21:43 -0800595 for scale32 in [False, True]:
596 for double_round in [False, True]:
597 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700598
599 if inDtype == DType.INT48 and scale32:
600 # Illegal condition. Must be scale32=False
601 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100602 if double_round and not scale32:
603 # Illegal condition. ERROR_IF(!scale32 && double_round)
604 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Kevin Cheng550ccc52021-03-03 11:21:43 -0800606 arg_list.append(
607 (
608 "out{}_sc{}_dr{}_pc{}".format(
609 DTypeNames[dtype],
610 int(scale32),
611 int(double_round),
612 int(per_channel),
613 ),
614 [dtype, scale32, double_round, per_channel],
615 )
616 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700617
618 return arg_list
619
Kevin Chengaee1fac2020-11-11 13:54:06 -0800620 @staticmethod
621 def agMul(testGen, opName, shapeList, dtype):
622 arg_list = []
623
624 if dtype is DType.INT32:
625 for p in range(testGen.args.num_rand_permutations):
626
627 shift = testGen.randInt(0, 32)
628
Kevin Cheng550ccc52021-03-03 11:21:43 -0800629 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800630 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100631 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800632
633 return arg_list
634
635 @staticmethod
636 def agArithmeticRightShift(testGen, opName, shapeList, dtype):
637 arg_list = []
638
Kevin Cheng550ccc52021-03-03 11:21:43 -0800639 arg_list.append(("roundTrue", [True]))
640 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800641
642 return arg_list
643
Eric Kunzee5e26762020-10-13 16:11:07 -0700644 # Helper function for reshape. Gets some factors of a larger number.
645 @staticmethod
646 def getFactors(val, start=1):
647 factors = []
648
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100649 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700650 if (val % i) == 0:
651 factors.append(i)
652
653 return factors
654
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        """Generate random target shapes (same element count) for RESHAPE.

        For each requested permutation: pick a random rank, build a random
        factorization of the total element count, occasionally substitute a
        -1 wildcard dimension, and drop duplicates of earlier results.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total element count that every generated shape must preserve.
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough distinct factors to fill this rank — skip.
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    # Re-factorize what's left so the product stays exact.
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension absorbs whatever is left.
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
710
Eric Kunzee5e26762020-10-13 16:11:07 -0700711 @staticmethod
712 def agTranspose(testGen, opName, shapeList, dtype):
713 arg_list = []
714
715 ifm_shape = shapeList[0]
716
Jeremy Johnsona6185572021-06-21 15:55:35 +0100717 # Get all permutations
718 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700719
Jeremy Johnsona6185572021-06-21 15:55:35 +0100720 # Limit to possible permutations from shape dimension or argument setting
721 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700722
Jeremy Johnsona6185572021-06-21 15:55:35 +0100723 # Get random permutation generator that uses all permutations
724 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700725
Jeremy Johnsona6185572021-06-21 15:55:35 +0100726 # Create list of required amount of permutations
727 arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
Eric Kunzee5e26762020-10-13 16:11:07 -0700728 return arg_list
729
730 @staticmethod
731 def agSlice(testGen, opName, shapeList, dtype):
732 arg_list = []
733
734 ifm_shape = shapeList[0]
735 rank = len(ifm_shape)
736
737 for p in range(testGen.args.num_rand_permutations):
738 begin = []
739 size = []
740
Kevin Cheng550ccc52021-03-03 11:21:43 -0800741 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700742
743 for i in range(rank):
744 if ifm_shape[i] > 1:
745 begin.append(testGen.randInt(0, ifm_shape[i]))
746 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
747
748 # Invalid slice size?
749 if size[i] == 0:
750 valid = False
751 else:
752 begin.append(0)
753 size.append(1)
754
755 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800756 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700757 return arg_list
758
759 @staticmethod
760 def agTile(testGen, opName, shapeList, dtype):
761 arg_list = []
762
763 ifm_shape = shapeList[0]
764 rank = len(ifm_shape)
765
766 for p in range(testGen.args.num_rand_permutations):
767
768 # Pick a few random, but small multiple values
769 # because otherwise this has a tendency to generate
770 # enormous tensors
771 multiples = []
772 for i in range(rank):
773 multiples.append(testGen.randInt(1, 4))
774
Kevin Cheng550ccc52021-03-03 11:21:43 -0800775 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700776
777 return arg_list
778
779 @staticmethod
780 def agResize(testGen, opName, shapeList, dtype):
781 arg_list = []
782
783 ifm_shape = shapeList[0]
784
785 for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
786
787 # Exclude illegal {mode, type} configurations. Pick legal output types
788 if m == ResizeMode.NEAREST and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +0100789 outputDTypeList = [DType.INT8]
Eric Kunzee5e26762020-10-13 16:11:07 -0700790 elif m == ResizeMode.NEAREST and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800791 outputDTypeList = [DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -0700792 elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +0100793 outputDTypeList = [DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700794 elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800795 outputDTypeList = [DType.INT48]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800796 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800797 outputDTypeList = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700798 else:
799 continue
800
801 for outputDType in outputDTypeList:
802 for perm in range(testGen.args.num_rand_permutations):
803
804 # Randomly generate legal output dimensions and shift
805 # and then compute the stride and offset based on them
Kevin Cheng550ccc52021-03-03 11:21:43 -0800806 output_dims = [testGen.randInt(1), testGen.randInt(1)]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800807 in_center_h = (ifm_shape[1] - 1) / 2.0
808 in_center_w = (ifm_shape[2] - 1) / 2.0
809 out_center_h = (output_dims[0] - 1) / 2.0
810 out_center_w = (output_dims[1] - 1) / 2.0
Eric Kunzee5e26762020-10-13 16:11:07 -0700811
Kevin Cheng77d0f762020-11-24 10:26:32 -0800812 fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
813 fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
814 fp_offset_y = in_center_h - fp_stride_y * out_center_h
815 fp_offset_x = in_center_w - fp_stride_x * out_center_w
Eric Kunzee5e26762020-10-13 16:11:07 -0700816
Kevin Cheng77d0f762020-11-24 10:26:32 -0800817 if outputDType == DType.FLOAT:
818 shift = 0
819 stride = [0, 0]
820 offset = [0, 0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800821 stride_fp = [fp_stride_y, fp_stride_x]
822 offset_fp = [fp_offset_y, fp_offset_x]
823 arg_list.append(
824 (
825 "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
Les Bell33d837e2021-08-10 08:34:43 +0100826 "N" if m == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800827 output_dims[0],
828 output_dims[1],
829 testGen.typeStr(outputDType),
830 stride_fp[0],
831 stride_fp[1],
832 offset_fp[0],
833 offset_fp[1],
834 ),
835 [
836 m,
837 stride,
838 offset,
839 shift,
840 stride_fp,
841 offset_fp,
842 output_dims,
843 dtype,
844 outputDType,
845 ],
846 )
847 )
Kevin Cheng77d0f762020-11-24 10:26:32 -0800848 else:
849 shift = 11
850 unit = float(1 << shift)
851 stride_y = int(round(fp_stride_y * unit))
852 stride_x = int(round(fp_stride_x * unit))
853 offset_y = int(round(fp_offset_y * unit))
854 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -0700855
Kevin Cheng550ccc52021-03-03 11:21:43 -0800856 while (
857 stride_y >= 32768
858 or stride_x >= 32768
859 or offset_y >= 32768
860 or offset_x >= 32768
861 or offset_y < -32768
862 or offset_x < -32768
863 ):
Kevin Cheng77d0f762020-11-24 10:26:32 -0800864 shift = shift - 1
865 unit = float(1 << shift)
866 stride_y = int(round(fp_stride_y * unit))
867 stride_x = int(round(fp_stride_x * unit))
868 offset_y = int(round(fp_offset_y * unit))
869 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -0700870
Kevin Cheng550ccc52021-03-03 11:21:43 -0800871 stride = [stride_y, stride_x]
872 offset = [offset_y, offset_x]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800873
874 stride_fp = [0.0, 0.0]
875 offset_fp = [0.0, 0.0]
876
Kevin Cheng550ccc52021-03-03 11:21:43 -0800877 arg_list.append(
878 (
879 "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
Les Bell33d837e2021-08-10 08:34:43 +0100880 "N" if m == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800881 shift,
882 output_dims[0],
883 output_dims[1],
884 testGen.typeStr(outputDType),
885 stride[0],
886 stride[1],
887 offset[0],
888 offset[1],
889 ),
890 [
891 m,
892 stride,
893 offset,
894 shift,
895 stride_fp,
896 offset_fp,
897 output_dims,
898 dtype,
899 outputDType,
900 ],
901 )
902 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700903
904 return arg_list
905
906 def agCondIf(testGen, opName, shapeList, dtype):
907 # CondIf generates the condition values here.
908 # Convert to tensors in the build function, along with the
909 # then and else blocks
910 arg_list = []
911
912 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800913 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700914
915 return arg_list
916
917 def agWhileLoop(testGen, opName, shapeList, dtype):
918 # While loop: 0 iterations, 1, more than 1
919 arg_list = []
920
921 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800922 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700923
924 return arg_list
925
Kevin Cheng550ccc52021-03-03 11:21:43 -0800926
Eric Kunzee5e26762020-10-13 16:11:07 -0700927class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +0100928 # Maximum rank of tensor supported by test generator.
929 TOSA_TENSOR_MAX_RANK = 6
930
    def __init__(self, args):
        """Set up the generator from parsed command-line arguments.

        Stores the output path and RNG seed, seeds a numpy Generator, and
        builds the operator lists before any tests are created.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Serializer is created per-test by createSerializer()
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        # Op lists must be built before tests run; dynamic lists first so the
        # defaults can reference them
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
942
943 def createSerializer(self, opName, testPath):
944 self.testPath = os.path.join(opName, testPath)
945
946 fullPath = os.path.join(self.basePath, self.testPath)
947 os.makedirs(fullPath, exist_ok=True)
948 self.ser = ts.TosaSerializer(fullPath)
949
950 def getSerializer(self):
951 return self.ser
952
953 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800954 with open(
955 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
956 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700957 fd.write(self.ser.serialize())
958
Kevin Cheng550ccc52021-03-03 11:21:43 -0800959 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
960 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700961
962 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700963 if dtype == DType.BOOL:
964 np_dt = np.bool
965 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -0700966 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700967 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700968 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700969 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100970 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
971 elif dtype == DType.UINT8:
972 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700973 elif dtype == DType.INT16:
974 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
975 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800976 return np.int32(
977 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
978 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700979 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800980 return np.int64(
981 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
982 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700983 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100984 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700985 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800986 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700987
Kevin Cheng989cb052021-04-28 16:29:44 -0700988 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700989 placeholders = []
990
Kevin Cheng989cb052021-04-28 16:29:44 -0700991 assert len(shape_list) == len(dtype_list)
992
993 for idx, shape in enumerate(shape_list):
994 arr = self.getRandTensor(shape, dtype_list[idx])
995 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700996
997 return placeholders
998
Kevin Cheng989cb052021-04-28 16:29:44 -0700999 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001000 consts = []
1001
Kevin Cheng989cb052021-04-28 16:29:44 -07001002 assert len(shape_list) == len(dtype_list)
1003
1004 for idx, shape in enumerate(shape_list):
1005 arr = self.getRandTensor(shape, dtype_list[idx])
1006 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001007
1008 return consts
1009
1010 def makeShape(self, rank):
1011 if self.targetted_shape:
1012 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001013 return np.int32(
1014 self.rng.integers(
1015 low=self.args.tensor_shape_range[0],
1016 high=self.args.tensor_shape_range[1],
1017 size=rank,
1018 )
1019 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001020
1021 def setTargetShape(self, shape):
1022 self.targetted_shape = shape
1023
1024 def randInt(self, low=0, high=256):
1025 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
1026
    def getRandNumberDType(self, dtype):
        """Return a single random scalar of the given TOSA dtype.

        FLOAT yields a python float in [0, 1); BOOL yields a bool; integer
        dtypes yield an np.int32 (np.int64 for INT48) covering the dtype's
        full value range.  Raises Exception for an unknown dtype.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size: 48-bit values overflow int32, so return int64 here
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        # All remaining (<= 32-bit) integer dtypes fit in int32
        return np.int32(self.rng.integers(low, high, size=1))[0]
1049
1050 def shapeStr(self, shape):
1051
1052 sStr = []
1053 # Convert to strings
1054 for i in shape:
1055 sStr.append(str(i))
1056
Kevin Cheng550ccc52021-03-03 11:21:43 -08001057 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001058
1059 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001060 if isinstance(t, list):
1061 assert len(t) >= 2
1062 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001063 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001064 if t == DType.BOOL:
1065 return "b"
1066 elif t == DType.INT4:
1067 return "i4"
1068 elif t == DType.INT8:
1069 return "i8"
1070 elif t == DType.UINT8:
1071 return "u8"
1072 elif t == DType.INT16:
1073 return "i16"
1074 elif t == DType.INT32:
1075 return "i32"
1076 elif t == DType.INT48:
1077 return "i48"
1078 elif t == DType.FLOAT:
1079 return "float"
1080 else:
1081 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001082
1083 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001084 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001085 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001086 return 4
1087 elif t == DType.INT8:
1088 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001089 elif t == DType.UINT8:
1090 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001091 elif t == DType.INT16:
1092 return 16
1093 elif t == DType.INT32:
1094 return 32
1095 elif t == DType.INT48:
1096 return 48
1097 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001098 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001099
1100 # Argument generators
1101 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1102 # Where the string descriptor is used to generate the test name and
1103 # The build_fcn_arg_list is expanded and passed to the operator test
1104 # build function
1105
Kevin Cheng550ccc52021-03-03 11:21:43 -08001106 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001107 result_tens = OutputShaper.unaryOp(self.ser, a)
1108 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1109 return result_tens
1110
1111 def build_binary_broadcast(self, op, a, b):
1112 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1113 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1114 return result_tens
1115
1116 def build_binary_nonbroadcast(self, op, a, b):
1117 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
1118 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1119 return result_tens
1120
Kevin Chengaee1fac2020-11-11 13:54:06 -08001121 def build_arithmetic_right_shift(self, op, a, b, round):
1122 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1123
1124 attr = ts.TosaSerializerAttribute()
1125 attr.ArithmeticRightShiftAttribute(round)
1126
1127 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1128 return result_tens
1129
1130 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001131 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1132
1133 # Special for multiply:
1134 # Force the result to INT32 for INT types
1135 if a.dtype != DType.FLOAT:
1136 result_tens.setDtype(DType.INT32)
1137
Kevin Chengaee1fac2020-11-11 13:54:06 -08001138 attr = ts.TosaSerializerAttribute()
1139 attr.MulAttribute(shift)
1140
1141 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001142 return result_tens
1143
1144 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001145 # Constant size depending on type, random values
1146 if a.dtype == DType.INT16:
1147 table_dtype = DType.INT16
1148 table_arr = self.getRandTensor([513], table_dtype)
1149 else:
1150 assert a.dtype == DType.INT8
1151 table_dtype = DType.INT8
1152 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001153
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001154 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1155 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001156 self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
1157
1158 return result_tens
1159
1160 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001161 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
1162 self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001163 return result_tens
1164
1165 def build_comparison(self, op, a, b):
1166 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
1167 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1168 return result_tens
1169
1170 def build_argmax(self, op, a, axis):
1171 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1172
1173 attr = ts.TosaSerializerAttribute()
1174 attr.AxisAttribute(axis)
1175
1176 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1177 return result_tens
1178
Kevin Cheng550ccc52021-03-03 11:21:43 -08001179 def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001180 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1181
1182 attr = ts.TosaSerializerAttribute()
1183 attr.Pool2dAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001184
1185 self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
1186 return result_tens
1187
1188 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001189 assert len(padding) == 4
1190 result_tens = OutputShaper.conv2dOp(
1191 self.ser, ifm, filter, strides, padding, dilations
1192 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001193
1194 attr = ts.TosaSerializerAttribute()
1195 attr.Conv2dAttribute(padding, strides, dilations)
1196
Kevin Cheng550ccc52021-03-03 11:21:43 -08001197 self.ser.addOperator(
1198 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1199 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001200 return result_tens
1201
Kevin Cheng550ccc52021-03-03 11:21:43 -08001202 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001203 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001204 ):
1205 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001206 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1207
1208 attr = ts.TosaSerializerAttribute()
1209 attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
1210
Kevin Cheng550ccc52021-03-03 11:21:43 -08001211 self.ser.addOperator(
Kevin Cheng989cb052021-04-28 16:29:44 -07001212 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001213 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001214 return result_tens
1215
Kevin Cheng550ccc52021-03-03 11:21:43 -08001216 def build_depthwise_conv2d(
1217 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1218 ):
1219 result_tens = OutputShaper.depthwiseConv2dOp(
1220 self.ser, ifm, filter, strides, padding, dilations
1221 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001222
1223 attr = ts.TosaSerializerAttribute()
1224 attr.Conv2dAttribute(padding, strides, dilations)
1225
Kevin Cheng550ccc52021-03-03 11:21:43 -08001226 self.ser.addOperator(
1227 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1228 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001229 return result_tens
1230
1231 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
1232 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
1233
Kevin Cheng550ccc52021-03-03 11:21:43 -08001234 self.ser.addOperator(
1235 op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
1236 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001237 return result_tens
1238
1239 def build_matmul(self, op, a, b, qinfo):
1240 result_tens = OutputShaper.matmulOp(self.ser, a, b)
1241 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
1242 return result_tens
1243
1244 def build_reduce(self, op, a, axis):
1245 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
1246
1247 attr = ts.TosaSerializerAttribute()
1248 attr.AxisAttribute(axis)
1249
1250 self.ser.addOperator(op, [a.name], result_tens.name, attr)
1251 return result_tens
1252
1253 def build_clamp(self, op, a):
1254 result_tens = OutputShaper.unaryOp(self.ser, a)
1255
1256 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01001257 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001258
1259 if a.dtype == DType.FLOAT:
1260 attr.ClampAttribute(0, 0, min(v), max(v))
1261 else:
1262 attr.ClampAttribute(min(v), max(v), 0, 0)
1263
1264 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1265 return result_tens
1266
1267 def build_leaky_relu(self, op, a):
1268 result_tens = OutputShaper.unaryOp(self.ser, a)
1269 attr = ts.TosaSerializerAttribute()
1270
1271 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1272
1273 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1274 return result_tens
1275
1276 # Needs an additional type/input
1277 def build_prelu(self, op, a):
1278 result_tens = OutputShaper.unaryOp(self.ser, a)
1279
1280 self.ser.addOperator(op, [a.name], [result_tens.name])
1281 return result_tens
1282
1283 def build_relun(self, op, a):
1284 result_tens = OutputShaper.unaryOp(self.ser, a)
1285
1286 attr = ts.TosaSerializerAttribute()
1287
1288 if a.dtype == DType.FLOAT:
1289 attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
1290 else:
1291 attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)
1292
1293 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1294 return result_tens
1295
1296 def build_sigmoid(self, op, a):
1297 result_tens = OutputShaper.unaryOp(self.ser, a)
1298 self.ser.addOperator(op, [a.name], [result_tens.name])
1299 return result_tens
1300
1301 def build_tanh(self, op, a):
1302 result_tens = OutputShaper.unaryOp(self.ser, a)
1303 self.ser.addOperator(op, [a.name], [result_tens.name])
1304 return result_tens
1305
Matthew Haddon818ab902021-07-27 09:12:49 +01001306 def build_concat(self, op, *a):
1307 assert (type(a[-1]) == int)
1308
1309 # To store variable length list of input tensors we need to store axis along with it
1310 axis = a[-1]
1311 a = a[:-1]
1312
1313 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07001314
1315 attr = ts.TosaSerializerAttribute()
1316 attr.AxisAttribute(axis)
1317
Matthew Haddon818ab902021-07-27 09:12:49 +01001318 input_tensor_names = []
1319 for tensor in a:
1320 input_tensor_names.append(tensor.name)
1321
1322 self.ser.addOperator(op, input_tensor_names, [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001323
1324 def build_pad(self, op, a, padding, qinfo):
1325 result_tens = OutputShaper.padOp(self.ser, a, padding)
1326
1327 # Need to turn the padding array into a TOSA tensor here.
1328 # This is one of the few tensor operands that does not get
1329 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001330 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001331
Kevin Cheng550ccc52021-03-03 11:21:43 -08001332 self.ser.addOperator(
1333 op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
1334 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001335
1336 def build_reshape(self, op, a, newShape):
1337 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
1338
1339 attr = ts.TosaSerializerAttribute()
1340 attr.ReshapeAttribute(newShape)
1341
1342 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1343 return result_tens
1344
1345 def build_reverse(self, op, a, axis):
1346 result_tens = OutputShaper.unaryOp(self.ser, a)
1347
1348 attr = ts.TosaSerializerAttribute()
1349 attr.AxisAttribute(axis)
1350
1351 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1352 return result_tens
1353
1354 def build_transpose(self, op, a, perms):
1355 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
1356
Kevin Cheng550ccc52021-03-03 11:21:43 -08001357 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07001358
1359 self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
1360 return result_tens
1361
1362 def build_slice(self, op, a, begin, size):
1363 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
1364
1365 attr = ts.TosaSerializerAttribute()
1366 attr.SliceAttribute(begin, size)
1367
1368 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1369 return result_tens
1370
1371 def build_tile(self, op, a, multiples):
1372 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
1373
1374 attr = ts.TosaSerializerAttribute()
1375 attr.TileAttribute(multiples)
1376
1377 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1378 return result_tens
1379
Kevin Cheng77d0f762020-11-24 10:26:32 -08001380 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07001381
1382 # Create a new indicies tensor
1383 # here with data that doesn't exceed the dimensions of the values tensor
1384
Kevin Cheng550ccc52021-03-03 11:21:43 -08001385 K = values.shape[1] # K
1386 W = self.randInt(
1387 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1388 ) # W
1389 indicies_arr = np.int32(
1390 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1391 ) # (N, W)
1392 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001393
Kevin Cheng77d0f762020-11-24 10:26:32 -08001394 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07001395
Kevin Cheng77d0f762020-11-24 10:26:32 -08001396 self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001397
1398 return result_tens
1399
Kevin Cheng77d0f762020-11-24 10:26:32 -08001400 def build_scatter(self, op, values_in, input):
1401
1402 # Create a new indicies tensor
1403 # here with data that doesn't exceed the dimensions of the values_in tensor
1404
Kevin Cheng550ccc52021-03-03 11:21:43 -08001405 K = values_in.shape[1] # K
1406 W = input.shape[1] # W
1407 indicies_arr = np.int32(
1408 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1409 ) # (N, W)
1410 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001411
1412 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
1413
Kevin Cheng550ccc52021-03-03 11:21:43 -08001414 self.ser.addOperator(
1415 op, [values_in.name, indicies.name, input.name], [result_tens.name]
1416 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001417
1418 return result_tens
1419
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        """Build RESIZE, forwarding the precomputed stride/offset arguments.

        Integer paths use (stride, offset, shift); the float path uses
        (stride_fp, offset_fp) with the integer fields zeroed — the argument
        generator is responsible for filling in the unused set with zeros.
        """
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens
1456
1457 def build_identityn(self, op, val, val2):
1458
Kevin Cheng550ccc52021-03-03 11:21:43 -08001459 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001460 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001461 self.ser.addOperator(
1462 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1463 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001464 return result_tens
1465
1466 def build_placeholder(self, op, val):
1467 # Add an identity op to avoid warning in the reference model
1468 return self.build_unary(Op.IDENTITY, val)
1469
1470 # Type Conversion
1471 def build_cast(self, op, val, out_dtype):
1472 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1473 self.ser.addOperator(op, [val.name], [result_tens.name])
1474 return result_tens
1475
    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
        """Build RESCALE with random zero-points and per-channel scales.

        Randomly draws zero-points for (U)INT8 endpoints and a scale per
        channel, then converts each scale to a (multiplier, shift) pair.
        Marks the test as an expected failure when a shift falls outside
        the legal [2, 62] range.
        """
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

        # One scale per channel, or a single scale for the whole tensor
        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        in_type_width = self.typeWidth(val.dtype)
        out_type_width = self.typeWidth(out_dtype)

        # (U)INT8 endpoints get a random zero-point; the effective width is
        # widened by one bit to account for the zero-point headroom
        if val.dtype == DType.INT8:
            input_zp = self.randInt(-128, 128)
            in_type_width = in_type_width + 1
        elif val.dtype == DType.UINT8:
            input_zp = self.randInt(0, 256)
            in_type_width = in_type_width + 1
        else:
            input_zp = 0

        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 128)
            out_type_width = out_type_width + 1
        elif out_dtype == DType.UINT8:
            output_zp = self.randInt(0, 256)
            out_type_width = out_type_width + 1
        else:
            output_zp = 0

        # Calculate scale based on:
        # scale = a *(2^output_width)/(2^input_width))

        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

        if scale32:
            pass
            # NOTE(review): the stray `pass` above is harmless; left in place
            # to keep the code byte-identical.
            # Cap the scaling for the 32-bit multiplier path
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

        multiplier_arr = np.int32(np.zeros(shape=[nc]))
        shift_arr = np.int32(np.zeros(shape=[nc]))

        # Convert each float scale to a fixed-point multiplier/shift pair
        for i in range(nc):
            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                scale_arr[i], scale32
            )
            if shift_arr[i] < 2 or shift_arr[i] > 62:
                self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")

        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            multiplier_arr,
            shift_arr,
            scale32,
            double_round,
            per_channel,
        )

        self.ser.addOperator(op, [val.name], [result_tens.name], attr)
        return result_tens
1546
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Build COND_IF whose then/else blocks are constant tensors.

        The supplied then/else tensors are used only for their shape; each
        block body is a freshly generated random const of that shape.
        """
        # For cond_if with constants, we're supplied with then/else tensors
        # that we ignore (except for the generated shape) and the condition.
        # Build Then/Else blocks and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
1582
    def build_cond_if_binary(self, op, a, b, cond):
        """Build COND_IF whose branches compute a+b (then) or a-b (else).

        For cond_if with a binary op in the then/else blocks, take a and b
        and alternately add or subtract them based on the condition.
        """
        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.currBasicBlock.addOutput(result_tens.name)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # THEN: result = a + b
        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        # ELSE: result = a - b
        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
1617
1618 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001619 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001620
Kevin Cheng550ccc52021-03-03 11:21:43 -08001621 cond_block = "COND_BLOCK"
1622 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001623
1624 attr = ts.TosaSerializerAttribute()
1625 attr.WhileLoopAttribute(cond_block, body_block)
1626
1627 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001628 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001629 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001630 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001631
1632 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001633 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1634 a_out = self.ser.addIntermediate(a.shape, a.dtype)
1635 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001636
1637 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001638 self.ser.addOperator(
1639 op,
1640 [iter.name, a.name, acc.name],
1641 [iter_out.name, a_out.name, acc_out.name],
1642 attr,
1643 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001644
1645 # COND block (input: iter, output: cond_tens )
1646 self.ser.startBasicBlock(cond_block)
1647 self.ser.addInputTensor(iter)
1648 self.ser.addInputTensor(a)
1649 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001650 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
1651 cond_tens = self.ser.addOutput([], DType.BOOL)
1652 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001653
1654 # BODY block (input: a, acc, iter, output: a, acc, iter)
1655 # Note that local intermediate tensors need to be declared here for the outputs
1656 self.ser.startBasicBlock(body_block)
1657 self.ser.addInputTensor(iter)
1658 self.ser.addInputTensor(a)
1659 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001660 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
1661 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1662 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001663 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
1664 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
1665 self.ser.addOutputTensor(iter_body_out)
1666 self.ser.addOutputTensor(a)
1667 self.ser.addOutputTensor(acc_body_out)
1668
1669 return acc_out
1670
Kevin Cheng550ccc52021-03-03 11:21:43 -08001671 def genOpTestList(
1672 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
1673 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001674
1675 try:
1676 op = self.TOSA_OP_LIST[opName]
1677 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001678 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001679
1680 # Initialize a new random number generator
1681 self.rng = np.random.default_rng(self.random_seed)
1682
Kevin Cheng550ccc52021-03-03 11:21:43 -08001683 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001684
1685 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001686 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001687
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001688 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1689 default_test_rank_range = range(1, 5)
1690
Eric Kunzee5e26762020-10-13 16:11:07 -07001691 # Test list consists of a tuple of:
1692 # (opName, testNameStr, dtype, shapeList, argumentsList)
1693 testList = []
1694
1695 if not shapeFilter:
1696 shapeFilter = [None]
1697
1698 for r in range(rmin, rmax + 1):
1699
1700 # Filter out the rank?
1701 if rankFilter is not None and r not in rankFilter:
1702 continue
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001703 if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
1704 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001705
Kevin Cheng550ccc52021-03-03 11:21:43 -08001706 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001707
1708 # Filter tests based on dtype?
1709 if dtypeFilter is not None:
Les Bell30e46802021-07-23 09:43:31 +01001710 if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
Eric Kunzee5e26762020-10-13 16:11:07 -07001711 continue
1712
1713 # Create the placeholder and const tensors
1714 for shape in shapeFilter:
1715 # A None shape chooses a random shape of a given rank
1716
1717 # Filter out by rank
1718 if shape is not None and len(shape) != r:
1719 continue
1720
1721 self.setTargetShape(shape)
1722 shapeList = tgen_fcn(self, op, r)
1723
1724 shapeStr = self.shapeStr(shapeList[0])
1725 typeStr = self.typeStr(t)
1726
1727 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
1728 argList = []
1729 if agen_fcn:
1730 argList = agen_fcn(self, opName, shapeList, t)
1731 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001732 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001733
1734 for argStr, args in argList:
1735 if argStr:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001736 testStr = "{}_{}_{}_{}".format(
1737 opName, shapeStr, typeStr, argStr
1738 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001739 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001740 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001741
1742 testList.append((opName, testStr, t, shapeList, args))
1743
1744 return testList
1745
    def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
        """Build the tensor operands and operator for one test case, then
        serialize it to disk.

        Args:
            opName: key into TOSA_OP_LIST for the operator under test
            testStr: unique test name, used as the serializer output name
            dtype_or_dtypeList: a single DType applied to all operands, or a
                list giving a per-operand DType (mixed-dtype ops like conv)
            shapeList: list of operand shapes from the tensor generator
            testArgs: extra positional arguments forwarded to the build fcn

        Raises:
            Exception: if opName is not in TOSA_OP_LIST.
        """
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        # Normalize the dtype argument into one dtype per operand.  CONCAT is
        # special: its operand count is driven by shapeList, not op["operands"].
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        elif op['op'] == Op.CONCAT:
            dtypeList = [dtype_or_dtypeList] * len(shapeList)
        else:
            dtypeList = [dtype_or_dtypeList] * (num_operands)

        # Sanity-check list lengths (CONCAT exempt, see above).
        if op['op'] != Op.CONCAT:
            assert (
                len(shapeList) == num_operands
            ), "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
            assert (
                len(dtypeList) == num_operands
            ), "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )

        # Optional quantization-info generator for quantized ops.
        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = []

        # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
        if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Shift operand: constrain values to the input's bit width.
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.DIV:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.Div must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.DIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Re-roll the random operands until neither case is present.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    # Halve both operands until the shifted product fits in
                    # int32 for every element.
                    i = 0
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2 ** 31)).all() and (
                            result_arr <= ((2 ** 31) - 1)
                        ).all():
                            break

                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            # Split the inputs into leading placeholders and trailing consts
            # per the --num-const-inputs-concat option, always keeping at
            # least one placeholder input.
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            # Default case: pCount placeholders followed by cCount consts.
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        if qgen is not None:
            qinfo = qgen(self, op, dtype_or_dtypeList)
        else:
            qinfo = None

        # Invoke the op-specific build function with the operand tensors and
        # test arguments (plus qinfo when the op is quantized).
        try:
            if qinfo is not None:
                resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
            else:
                resultName = build_fcn(self, op["op"], *tens, *testArgs)
        except TypeError as e:
            # Dump the call details before re-raising to ease debugging of
            # mismatched build-function signatures.
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        # Save the serialized test
        self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001944
1945 def createDynamicOpLists(self):
1946
1947 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001948 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001949
1950 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001951 testName = "conv2d_{}x{}".format(k[0], k[1])
1952 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1953 self.TOSA_OP_LIST[testName]["filter"] = k
1954 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001955
Kevin Cheng550ccc52021-03-03 11:21:43 -08001956 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1957 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1958 "depthwise_conv2d_TEMPLATE"
1959 ].copy()
1960 self.TOSA_OP_LIST[testName]["filter"] = k
1961 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001962
Kevin Cheng550ccc52021-03-03 11:21:43 -08001963 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1964 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1965 "transpose_conv2d_TEMPLATE"
1966 ].copy()
1967 self.TOSA_OP_LIST[testName]["filter"] = k
1968 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001969
1970 # Delete any templates after having created any dynamic ops
1971 # This is a two-pass operation because it's bad practice to delete
1972 # keys from dictionaries while iterating
1973 keyList = []
1974 for k in self.TOSA_OP_LIST:
1975 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001976 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001977 keyList.append(k)
1978 continue
1979 except KeyError:
1980 pass
1981
1982 for k in keyList:
1983 del self.TOSA_OP_LIST[k]
1984
1985 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001986 """Fill in default fields for ops if they aren't already specified.
1987 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001988 for op in self.TOSA_OP_LIST:
1989
1990 # Required fields
1991 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001992 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001993 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001994 raise Exception(
1995 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1996 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001997
1998 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001999 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002000 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002001 raise Exception(
2002 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2003 op
2004 )
2005 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002006
2007 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002008 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002009 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002010 raise Exception(
2011 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2012 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002013
2014 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002015 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002016 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002017 raise Exception(
2018 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2019 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002020
2021 # Put in default rank range, if missing
2022 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002023 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002024 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002025 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002026
    # Tensor operator list
    #  'op': op name
    #  'operands': tuple of (placeholder, const) operands
    #  'rank': optional, restricts rank to tuple inclusive of (min, max),
    #  if not specified, defaults to (1, 4)
    #  'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    #  'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Per-operand dtype combinations for convolution-style ops.
    # NOTE(review): each 3-element list presumably gives per-operand
    # (input, weight, accumulator) dtypes — confirm against tgConv2D/qgConv.
    TYPE_CONV2D = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Rank range applied by initOpListDefaults() when an op omits 'rank'.
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002054
2055 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002056 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002057 "argmax": {
2058 "op": Op.ARGMAX,
2059 "operands": (1, 0),
2060 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2061 "types": TYPE_NARROW_INT_FP,
2062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002063 "avg_pool2d": {
2064 "op": Op.AVG_POOL2D,
2065 "operands": (1, 0),
2066 "rank": (4, 4),
2067 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2068 "qgen": TosaQuantGen.qgUnary,
2069 "types": TYPE_NARROW_INT_FP,
2070 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002071 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002072 "conv2d_TEMPLATE": {
2073 "op": Op.CONV2D,
2074 "operands": (1, 2),
2075 "rank": (4, 4),
2076 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
2077 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002078 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002079 "template": True,
2080 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002081 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07002082 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002083 "depthwise_conv2d_TEMPLATE": {
2084 "op": Op.DEPTHWISE_CONV2D,
2085 "operands": (1, 2),
2086 "filter": [1, 1],
2087 "rank": (4, 4),
2088 "build_fcn": (
2089 build_depthwise_conv2d,
2090 TosaTensorGen.tgDepthwiseConv2D,
2091 TosaArgGen.agConv2D,
2092 ),
2093 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002094 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002095 "template": True,
2096 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002097 "fully_connected": {
2098 "op": Op.FULLY_CONNECTED,
2099 "operands": (1, 2),
2100 "rank": (2, 2),
2101 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2102 "qgen": TosaQuantGen.qgConv,
2103 "types": TYPE_CONV2D,
2104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002105 "matmul": {
2106 "op": Op.MATMUL,
2107 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002108 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002109 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2110 "qgen": TosaQuantGen.qgMatmul,
2111 "types": TYPE_NARROW_INT_FP,
2112 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002113 "max_pool2d": {
2114 "op": Op.MAX_POOL2D,
2115 "operands": (1, 0),
2116 "rank": (4, 4),
2117 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2118 "types": TYPE_NARROW_INT_FP,
2119 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002120 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002121 "transpose_conv2d_TEMPLATE": {
2122 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002123 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002124 "rank": (4, 4),
2125 "build_fcn": (
2126 build_transpose_conv2d,
2127 TosaTensorGen.tgTransposeConv2D,
2128 TosaArgGen.agTransposeConv2D,
2129 ),
2130 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002131 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002132 "template": True,
2133 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002134 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002135 "clamp": {
2136 "op": Op.CLAMP,
2137 "operands": (1, 0),
2138 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2139 "types": TYPE_NARROW_INT_FP,
2140 },
2141 "relun": {
2142 "op": Op.RELUN,
2143 "operands": (1, 0),
2144 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2145 "types": TYPE_FI32,
2146 },
2147 "sigmoid": {
2148 "op": Op.SIGMOID,
2149 "operands": (1, 0),
2150 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2151 "types": TYPE_FP,
2152 },
2153 "tanh": {
2154 "op": Op.TANH,
2155 "operands": (1, 0),
2156 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2157 "types": TYPE_FP,
2158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002159 # Elementwise Binary Operators
2160 "add": {
2161 "op": Op.ADD,
2162 "operands": (2, 0),
2163 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2164 "types": TYPE_FI32,
2165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002166 "arithmetic_right_shift": {
2167 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2168 "operands": (2, 0),
2169 "build_fcn": (
2170 build_arithmetic_right_shift,
2171 TosaTensorGen.tgBroadcastFuzz,
2172 TosaArgGen.agArithmeticRightShift,
2173 ),
2174 "types": TYPE_INT,
2175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002176 "bitwise_and": {
2177 "op": Op.BITWISE_AND,
2178 "operands": (2, 0),
2179 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2180 "types": TYPE_INT,
2181 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002182 "bitwise_or": {
2183 "op": Op.BITWISE_OR,
2184 "operands": (2, 0),
2185 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2186 "types": TYPE_INT,
2187 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002188 "bitwise_xor": {
2189 "op": Op.BITWISE_XOR,
2190 "operands": (2, 0),
2191 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2192 "types": TYPE_INT,
2193 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002194 "div": {
2195 "op": Op.DIV,
2196 "operands": (2, 0),
2197 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2198 "types": [DType.INT32],
2199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002200 "logical_and": {
2201 "op": Op.LOGICAL_AND,
2202 "operands": (2, 0),
2203 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2204 "types": TYPE_BOOL,
2205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002206 "logical_left_shift": {
2207 "op": Op.LOGICAL_LEFT_SHIFT,
2208 "operands": (2, 0),
2209 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2210 "types": TYPE_INT,
2211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002212 "logical_right_shift": {
2213 "op": Op.LOGICAL_RIGHT_SHIFT,
2214 "operands": (2, 0),
2215 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2216 "types": TYPE_INT,
2217 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002218 "logical_or": {
2219 "op": Op.LOGICAL_OR,
2220 "operands": (2, 0),
2221 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2222 "types": TYPE_BOOL,
2223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002224 "logical_xor": {
2225 "op": Op.LOGICAL_XOR,
2226 "operands": (2, 0),
2227 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2228 "types": TYPE_BOOL,
2229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002230 "maximum": {
2231 "op": Op.MAXIMUM,
2232 "operands": (2, 0),
2233 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2234 "types": TYPE_FI32,
2235 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002236 "minimum": {
2237 "op": Op.MINIMUM,
2238 "operands": (2, 0),
2239 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2240 "types": TYPE_FI32,
2241 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002242 "mul": {
2243 "op": Op.MUL,
2244 "operands": (2, 0),
2245 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2246 "types": TYPE_INT_FP,
2247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002248 "pow": {
2249 "op": Op.POW,
2250 "operands": (2, 0),
2251 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2252 "types": TYPE_FP,
2253 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002254 "sub": {
2255 "op": Op.SUB,
2256 "operands": (2, 0),
2257 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2258 "types": TYPE_FI32,
2259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002260 "table": {
2261 "op": Op.TABLE,
2262 # Use the automatic generation functions to create the input array
2263 # but create the table tensor in the build function, as it may be
2264 # a different type from the input
2265 "operands": (1, 0),
2266 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002267 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002269 # Elementwise Unary operators
2270 "abs": {
2271 "op": Op.ABS,
2272 "operands": (1, 0),
2273 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2274 "types": TYPE_FI32,
2275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002276 "bitwise_not": {
2277 "op": Op.BITWISE_NOT,
2278 "operands": (1, 0),
2279 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2280 "types": TYPE_INT,
2281 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002282 "ceil": {
2283 "op": Op.CEIL,
2284 "operands": (1, 0),
2285 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2286 "types": TYPE_FP,
2287 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002288 "clz": {
2289 "op": Op.CLZ,
2290 "operands": (1, 0),
2291 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2292 "types": [DType.INT32],
2293 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002294 "exp": {
2295 "op": Op.EXP,
2296 "operands": (1, 0),
2297 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2298 "types": TYPE_FP,
2299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002300 "floor": {
2301 "op": Op.FLOOR,
2302 "operands": (1, 0),
2303 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2304 "types": TYPE_FP,
2305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002306 "log": {
2307 "op": Op.LOG,
2308 "operands": (1, 0),
2309 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2310 "types": TYPE_FP,
2311 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002312 "logical_not": {
2313 "op": Op.LOGICAL_NOT,
2314 "operands": (1, 0),
2315 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2316 "types": TYPE_BOOL,
2317 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002318 "negate": {
2319 "op": Op.NEGATE,
2320 "operands": (1, 0),
2321 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2322 "qgen": TosaQuantGen.qgUnary,
2323 "types": TYPE_INT_FP,
2324 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002325 "reciprocal": {
2326 "op": Op.RECIPROCAL,
2327 "operands": (1, 0),
2328 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2329 "types": TYPE_FP,
2330 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002331 "rsqrt": {
2332 "op": Op.RSQRT,
2333 "operands": (1, 0),
2334 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2335 "types": TYPE_FP,
2336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002337 # Elementwise Ternary operators
2338 "select": {
2339 "op": Op.SELECT,
2340 "operands": (3, 0),
2341 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2342 "types": TYPE_FIB,
2343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002344 # Comparison operators
2345 "equal": {
2346 "op": Op.EQUAL,
2347 "operands": (2, 0),
2348 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2349 "types": TYPE_FI32,
2350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002351 "greater_equal": {
2352 "op": Op.GREATER_EQUAL,
2353 "operands": (2, 0),
2354 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2355 "types": TYPE_FI32,
2356 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002357 "greater": {
2358 "op": Op.GREATER,
2359 "operands": (2, 0),
2360 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2361 "types": TYPE_FI32,
2362 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002363 # Reduction operators
2364 "reduce_all": {
2365 "op": Op.REDUCE_ALL,
2366 "operands": (1, 0),
2367 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2368 "types": TYPE_BOOL,
2369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002370 "reduce_any": {
2371 "op": Op.REDUCE_ANY,
2372 "operands": (1, 0),
2373 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2374 "types": TYPE_BOOL,
2375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002376 "reduce_max": {
2377 "op": Op.REDUCE_MAX,
2378 "operands": (1, 0),
2379 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2380 "types": TYPE_INT_FP,
2381 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002382 "reduce_min": {
2383 "op": Op.REDUCE_MAX,
2384 "operands": (1, 0),
2385 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2386 "types": TYPE_INT_FP,
2387 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002388 "reduce_product": {
2389 "op": Op.REDUCE_PRODUCT,
2390 "operands": (1, 0),
2391 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2392 "types": TYPE_FP,
2393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002394 "reduce_sum": {
2395 "op": Op.REDUCE_SUM,
2396 "operands": (1, 0),
2397 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2398 "types": TYPE_FI32,
2399 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002400 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002401 "concat": {
2402 "op": Op.CONCAT,
2403 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01002404 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002405 "types": TYPE_FIB,
2406 },
2407 "pad": {
2408 "op": Op.PAD,
2409 "operands": (1, 0),
2410 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2411 "qgen": TosaQuantGen.qgPad,
2412 "types": TYPE_FIB,
2413 },
2414 "reshape": {
2415 "op": Op.RESHAPE,
2416 "operands": (1, 0),
2417 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2418 "types": TYPE_FIB,
2419 },
2420 "reverse": {
2421 "op": Op.REVERSE,
2422 "operands": (1, 0),
2423 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2424 "types": TYPE_FIB,
2425 },
2426 "slice": {
2427 "op": Op.SLICE,
2428 "operands": (1, 0),
2429 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2430 "types": TYPE_FIB,
2431 },
2432 "tile": {
2433 "op": Op.TILE,
2434 "operands": (1, 0),
2435 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2436 "types": TYPE_FIB,
2437 },
2438 "transpose": {
2439 "op": Op.TRANSPOSE,
2440 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002441 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002442 "build_fcn": (
2443 build_transpose,
2444 TosaTensorGen.tgBasic,
2445 TosaArgGen.agTranspose,
2446 ),
2447 "types": TYPE_FIB,
2448 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002449 # Data nodes
2450 "const": {
2451 "op": Op.CONST,
2452 "operands": (1, 0),
2453 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2454 "types": TYPE_FIB,
2455 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002456 "identity": {
2457 "op": Op.IDENTITY,
2458 "operands": (1, 0),
2459 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2460 "types": TYPE_FIB,
2461 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002462 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002463 "gather": {
2464 "op": Op.GATHER,
2465 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2466 "operands": (1, 0),
2467 "rank": (3, 3),
2468 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2469 "types": TYPE_INT_FP,
2470 },
2471 "scatter": {
2472 "op": Op.SCATTER,
2473 # Only specify 'values_in' tensor here.
2474 #'indices' and 'input' are generated in op building stage
2475 "operands": (2, 0),
2476 "rank": (3, 3),
2477 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2478 "types": TYPE_INT_FP,
2479 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002480 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002481 "resize": {
2482 "op": Op.RESIZE,
2483 "operands": (1, 0),
2484 "rank": (4, 4),
2485 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2486 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2487 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002488 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002489 "cast": {
2490 "op": Op.CAST,
2491 "operands": (1, 0),
2492 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2493 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2494 },
2495 "rescale": {
2496 "op": Op.RESCALE,
2497 "operands": (1, 0),
2498 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002499 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002500 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002501 # Custom
2502 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002503 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07002504 # Two varients of cond_if, one that generates one of two constant tensors (no
2505 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
2506 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002507 "cond_if_const": {
2508 "op": Op.COND_IF,
2509 "operands": (0, 2),
2510 "build_fcn": (
2511 build_cond_if_const,
2512 TosaTensorGen.tgBasic,
2513 TosaArgGen.agCondIf,
2514 ),
2515 "types": [DType.BOOL],
2516 },
2517 "cond_if_binary": {
2518 "op": Op.COND_IF,
2519 "operands": (2, 0),
2520 "build_fcn": (
2521 build_cond_if_binary,
2522 TosaTensorGen.tgBasic,
2523 TosaArgGen.agCondIf,
2524 ),
2525 "types": TYPE_FI32,
2526 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002527 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002528 "while_loop": {
2529 "op": Op.WHILE_LOOP,
2530 "operands": (0, 1),
2531 "build_fcn": (
2532 build_while_loop,
2533 TosaTensorGen.tgBasic,
2534 TosaArgGen.agWhileLoop,
2535 ),
2536 "types": [DType.INT32],
2537 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002538 }
2539
Kevin Cheng550ccc52021-03-03 11:21:43 -08002540
class OutputShaper:
    """Compute the expected output shape and datatype for common classes of
    TOSA operations.

    Every helper is a @staticmethod that derives the output tensor's shape
    and dtype from its inputs/attributes and registers it on the serializer
    via ser.addOutput(), returning the resulting output-tensor handle.
    Invalid parameter combinations are flagged with ser.setExpectedFailure()
    rather than raised, so negative tests can still be emitted.
    """

    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self) -> None:
        # Stateless: all functionality lives in the static methods.
        pass
2546
2547 # These methods return arguments that can be used for
2548 # creating a new output tensor
2549 @staticmethod
2550 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002551 assert len(a.shape) == len(b.shape)
2552 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002553
2554 shape = []
2555 for i in range(len(a.shape)):
2556 if a.shape[i] == 1:
2557 shape.append(b.shape[i])
2558 else:
2559 shape.append(a.shape[i])
2560
Kevin Cheng550ccc52021-03-03 11:21:43 -08002561 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002562
2563 @staticmethod
2564 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002565 assert len(a.shape) == len(b.shape)
2566 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002567
2568 shape = []
2569 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002570 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002571 shape.append(a.shape[i])
2572
Kevin Cheng550ccc52021-03-03 11:21:43 -08002573 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002574
2575 @staticmethod
2576 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002577 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002578
2579 @staticmethod
2580 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002581 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2582 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002583
2584 shape = []
2585 for i in range(len(a.shape)):
2586 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2587
Kevin Cheng550ccc52021-03-03 11:21:43 -08002588 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002589
2590 @staticmethod
2591 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002592 assert len(a.shape) == len(b.shape)
2593 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002594
2595 # Do broadcast
2596 shape = []
2597 for i in range(len(a.shape)):
2598 if a.shape[i] == 1:
2599 shape.append(b.shape[i])
2600 else:
2601 shape.append(a.shape[i])
2602
2603 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002604 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002605
2606 @staticmethod
2607 def reduceOp(ser, a, axis):
2608
2609 shape = a.shape.copy()
2610
2611 shape[axis] = 1
2612
Kevin Cheng550ccc52021-03-03 11:21:43 -08002613 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002614
2615 @staticmethod
2616 def argmaxOp(ser, a, axis):
2617 shape = a.shape.copy()
2618 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002619 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002620
2621 @staticmethod
2622 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
2623
2624 # IFM: NHWC
2625 # Filter: OHWI
2626 # OFM: NHWC
2627
2628 if len(padding) == 2:
2629 # Expand padding to 4 parameters in the case of transpose_conv2d
2630 # From H,W to T,B,L,R
2631 padding = [padding[0], padding[0], padding[1], padding[1]]
2632
Kevin Cheng550ccc52021-03-03 11:21:43 -08002633 h = (
2634 ifm.shape[1]
2635 - filter.shape[1]
2636 - (filter.shape[1] - 1) * (dilations[0] - 1)
2637 + padding[0]
2638 + padding[1]
2639 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002640
Kevin Cheng550ccc52021-03-03 11:21:43 -08002641 w = (
2642 ifm.shape[2]
2643 - filter.shape[2]
2644 - (filter.shape[2] - 1) * (dilations[1] - 1)
2645 + padding[2]
2646 + padding[3]
2647 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002648
2649 if h <= 0 or w <= 0:
2650 # Invalid test parameters?
2651 h = 0
2652 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002653 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002654
2655 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
2656
Kevin Cheng3a478572021-01-22 17:21:02 -08002657 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002658 out_dtype = DType.INT32
2659 elif ifm.dtype == DType.INT16:
2660 out_dtype = DType.INT48
2661 elif ifm.dtype == DType.FLOAT:
2662 out_dtype = DType.FLOAT
2663 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002664 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002665
Kevin Cheng550ccc52021-03-03 11:21:43 -08002666 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002667
2668 @staticmethod
2669 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
2670 # IFM: NHWC
2671 # Filter: HWCM
2672 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08002673 h = (
2674 ifm.shape[1]
2675 - filter.shape[0]
2676 - (filter.shape[0] - 1) * (dilations[0] - 1)
2677 + padding[0]
2678 + padding[1]
2679 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002680
Kevin Cheng550ccc52021-03-03 11:21:43 -08002681 w = (
2682 ifm.shape[2]
2683 - filter.shape[1]
2684 - (filter.shape[1] - 1) * (dilations[1] - 1)
2685 + padding[2]
2686 + padding[3]
2687 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002688
2689 if h <= 0 or w <= 0:
2690 # Invalid test parameters?
2691 h = 0
2692 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002693 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002694
2695 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
2696
Kevin Cheng3a478572021-01-22 17:21:02 -08002697 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002698 out_dtype = DType.INT32
2699 elif ifm.dtype == DType.INT16:
2700 out_dtype = DType.INT48
2701 elif ifm.dtype == DType.FLOAT:
2702 out_dtype = DType.FLOAT
2703 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002704 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002705
Kevin Cheng550ccc52021-03-03 11:21:43 -08002706 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002707
2708 @staticmethod
2709 def pool2dOp(ser, ifm, kernel, stride, pad):
2710 # input: NHWC
2711 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2712 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
2713
2714 if h <= 0 or w <= 0:
2715 # Invalid test parameters?
2716 h = 0
2717 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002718 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002719
2720 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002721 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002722
2723 @staticmethod
2724 def fullyConnectedOp(ser, input, filter):
2725 # input: N, IC
2726 # filter: OC, IC
2727 # output: N, OC
2728
2729 output_shape = [input.shape[0], filter.shape[0]]
2730
Kevin Cheng3a478572021-01-22 17:21:02 -08002731 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002732 out_dtype = DType.INT32
2733 elif input.dtype == DType.INT16:
2734 out_dtype = DType.INT48
2735 elif input.dtype == DType.FLOAT:
2736 out_dtype = DType.FLOAT
2737 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002738 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002739
Kevin Cheng550ccc52021-03-03 11:21:43 -08002740 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002741
2742 @staticmethod
2743 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07002744 # a: N, H, C
2745 # b: N, C, W
2746 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07002747
Kevin Cheng2d60f002021-06-09 14:18:32 -07002748 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002749
Kevin Cheng3a478572021-01-22 17:21:02 -08002750 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002751 out_dtype = DType.INT32
2752 elif a.dtype == DType.INT16:
2753 out_dtype = DType.INT48
2754 elif a.dtype == DType.FLOAT:
2755 out_dtype = DType.FLOAT
2756 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002757 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002758
Kevin Cheng550ccc52021-03-03 11:21:43 -08002759 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002760
2761 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01002762 def concatOp(ser, axis, *a):
2763 input1 = a[0]
2764 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07002765
Matthew Haddon818ab902021-07-27 09:12:49 +01002766 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07002767
Matthew Haddon818ab902021-07-27 09:12:49 +01002768 output_shape[axis] = input1.shape[axis]
2769
2770 for tensor in remaining_inputs:
2771 output_shape[axis] += tensor.shape[axis]
2772
2773 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002774
2775 @staticmethod
2776 def padOp(ser, a, padding):
2777
2778 output_shape = a.shape.copy()
2779
2780 for i in range(len(output_shape)):
2781 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
2782
Kevin Cheng550ccc52021-03-03 11:21:43 -08002783 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002784
2785 @staticmethod
2786 def reshapeOp(ser, a, shape):
2787 output_shape = shape.copy()
2788
2789 totalElements = 1
2790 for i in a.shape:
2791 totalElements *= i
2792
2793 # If there are any -1 elements, figure out what that dimension must be
2794 totalOutputElements = 1
2795 for i in output_shape:
2796 if i != -1:
2797 totalOutputElements *= i
2798
2799 # And fill it in
2800 for i in range(len(output_shape)):
2801 if output_shape[i] == -1:
2802 output_shape[i] = totalElements // totalOutputElements
2803
Kevin Cheng550ccc52021-03-03 11:21:43 -08002804 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002805
2806 @staticmethod
2807 def sliceOp(ser, a, begin, size):
2808
2809 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002810 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002811
2812 @staticmethod
2813 def tileOp(ser, a, multiples):
2814
2815 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002816 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002817
2818 for i in range(len(output_shape)):
2819 output_shape[i] = a.shape[i] * multiples[i]
2820
Kevin Cheng550ccc52021-03-03 11:21:43 -08002821 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002822
2823 @staticmethod
2824 def transposeOp(ser, a, perms):
2825 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002826 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002827
2828 for i in range(len(output_shape)):
2829 output_shape[i] = a.shape[perms[i]]
2830
Kevin Cheng550ccc52021-03-03 11:21:43 -08002831 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002832
2833 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08002834 def gatherOp(ser, values, indices):
2835 assert len(values.shape) == 3
2836 assert len(indices.shape) == 2
2837 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07002838
Kevin Cheng77d0f762020-11-24 10:26:32 -08002839 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
2840
Kevin Cheng550ccc52021-03-03 11:21:43 -08002841 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002842
2843 @staticmethod
2844 def scatterOp(ser, values_in, indices, input):
2845 assert len(values_in.shape) == 3
2846 assert len(indices.shape) == 2
2847 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08002848 assert values_in.shape[0] == indices.shape[0] # N
2849 assert input.shape[1] == indices.shape[1] # W
2850 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08002851
2852 output_shape = values_in.shape
2853
Kevin Cheng550ccc52021-03-03 11:21:43 -08002854 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002855
2856 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002857 def tableOp(ser, input, table_dtype):
2858 # Same shape as the input, but dtype dependent on table dtype
2859 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
2860 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
2861 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002862
2863 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08002864 def resizeOp(
2865 ser,
2866 input,
2867 mode,
2868 stride,
2869 offset,
2870 shift,
2871 stride_fp,
2872 offset_fp,
2873 output_dims,
2874 input_dtype,
2875 output_dtype,
2876 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002877
2878 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
2879
Kevin Cheng77d0f762020-11-24 10:26:32 -08002880 if input_dtype == DType.FLOAT:
2881 if stride_fp[0] <= 0 or stride_fp[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002882 ser.setExpectedFailure(True, "Negative or zero stride")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002883 else:
2884 if stride[0] <= 0 or stride[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002885 ser.setExpectedFailure(True, "Negative or zero stride")
Eric Kunzee5e26762020-10-13 16:11:07 -07002886
Kevin Chengaee1fac2020-11-11 13:54:06 -08002887 if mode == ResizeMode.BILINEAR:
2888 if input_dtype == DType.INT8:
2889 if output_dtype != DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002890 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002891 elif input_dtype == DType.INT16:
2892 if output_dtype != DType.INT48:
Kevin Cheng989cb052021-04-28 16:29:44 -07002893 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002894 elif input_dtype == DType.FLOAT:
2895 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002896 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002897 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002898 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002899
2900 elif mode == ResizeMode.NEAREST:
2901 if input_dtype == DType.INT8:
2902 if output_dtype != DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002903 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002904 elif input_dtype == DType.INT16:
2905 if output_dtype != DType.INT16:
Kevin Cheng989cb052021-04-28 16:29:44 -07002906 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002907 elif input_dtype == DType.FLOAT:
2908 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002909 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002910 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002911 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002912
2913 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002914 ser.setExpectedFailure(true, "Invalid resize mode")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002915
Kevin Cheng550ccc52021-03-03 11:21:43 -08002916 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002917
2918 @staticmethod
2919 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002920 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002921
2922 @staticmethod
2923 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08002924 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002925 out_dtype = DType.INT32
2926 elif ifm.dtype == DType.INT16:
2927 out_dtype = DType.INT48
2928 elif ifm.dtype == DType.FLOAT:
2929 out_dtype = DType.FLOAT
2930 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002931 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002932
2933 if output_shape[1] <= 0 or output_shape[2] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002934 ser.setExpectedFailure(True, "Negative output shape")
Eric Kunzee5e26762020-10-13 16:11:07 -07002935
Kevin Cheng550ccc52021-03-03 11:21:43 -08002936 return ser.addOutput(output_shape, out_dtype)