#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()


class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition"""

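    # e.g. (illustrative) an operator definition entry might include
    #   "qgen": TosaQuantGen.qgUnary
    # which the test generator calls to build the QuantInfo for that op.
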
    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
                             TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
                              TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
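        # Worked example (illustrative): with scale32=True, scaleBits is 31;
        # math.frexp(0.125) returns (0.5, -2), so multiplier becomes
        # round(0.5 * 2**31) = 1 << 30 and shift = -(-2) + 31 = 33.
        # A value v is then rescaled as (v * multiplier) >> shift ~= v * 0.125.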
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        assert multiplier <= (1 << scaleBits)
        assert shift >= 0 and shift <= 63

        return multiplier, shift

class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

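    # e.g. (illustrative) tgBasic for an op with operands (pl=1, const=1) at
    # rank 2 might return [[3, 5], [3, 5]] - one shape per tensor operand.
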
    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrain the batch size if requested
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrain the batch size if requested
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # For the chosen input, pick a random dimension to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1
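                # e.g. (illustrative) shape [3, 4, 5] with fuzz_idx == 1
                # becomes [3, 1, 5], exercising broadcast along that axis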
            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrain the batch size if requested
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrain the batch size if requested
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrain the batch size if requested
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is H, W, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = (
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )[0]
        )
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis):
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        shape = shapeList[0]
        if len(shapeList) == 2 or shape[axis] < len(shapeList):
            return shapeList

        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

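    # e.g. (illustrative) agAxis on a rank-3 input yields
    #   [("axis0", [0]), ("axis1", [1]), ("axis2", [2])]
    # so "axis1" lands in the test name and 1 is passed as the axis argument.
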
    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
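        # Each loop index below is decoded into a parameter list with div and
        # mod, e.g. (illustrative) with maxStride == 2 the stride indices 0..3
        # decode to s = [1, 1], [1, 2], [2, 1], [2, 2]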
        for stride in range(0, maxStride ** 2):
            for padding in range(0, maxPadding ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, maxPadding ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (out_padding // (maxPadding * 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))
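        # e.g. (illustrative) for rank 2 with pad_max == 1 this enumerates all
        # 16 ((before, after), (before, after)) combinations, producing test
        # names "pad0000" through "pad1111"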
        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings)]))

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if inDtype == DType.UINT8 and dtype != DType.INT8:
                # The only legal output dtype for a UINT8 input is INT8; skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8:
                # The only legal input dtype for a UINT8 output is INT8; skip all other combinations
                continue

            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition. Must be scale32=False
                            continue
                        if double_round and not scale32:
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

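    # Note (illustrative): getFactors only returns factors up to sqrt(val),
    # e.g. getFactors(12) == [1, 2, 3]; agReshape recovers the larger
    # co-factors by dividing the remaining element count as it builds shapes.
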
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it continues for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank - 1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Shuffle the full set of permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create a list with the required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

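                        # e.g. (illustrative) fp_stride_y == 2.0 at shift 11
                        # gives stride_y = round(2.0 * 2048) = 4096; the loop
                        # below reduces shift until all values fit in int16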
                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list

class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))

    def getRandTensor(self, shape, dtype):
        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
        elif dtype == DType.UINT8:
            return np.int32(self.rng.integers(low=0, high=256, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(self.rng.random(size=shape))
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):

        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # Where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function
    def build_unary(self, op, a, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_table(self, op, a):
        # Constant size depending on type, random values
        if a.dtype == DType.INT16:
            table_dtype = DType.INT16
            table_arr = self.getRandTensor([513], table_dtype)
        else:
            assert a.dtype == DType.INT8
            table_dtype = DType.INT8
            table_arr = self.getRandTensor([256], table_dtype)

        table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
        result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
        self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)

        return result_tens

    def build_select(self, op, cond, a, b):
        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.Pool2dAttribute(kernel, stride, pad)

        self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis):
        result_tens = OutputShaper.reduceOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_clamp(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min(v), max(v))
        else:
            attr.ClampAttribute(min(v), max(v), 0, 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_leaky_relu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_relun(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        if a.dtype == DType.FLOAT:
            attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
        else:
            attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_sigmoid(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_tanh(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_concat(self, op, *a):
        assert type(a[-1]) == int

        # To store a variable-length list of input tensors we need to store the
        # axis along with it
        axis = a[-1]
        a = a[:-1]

        result_tens = OutputShaper.concatOp(self.ser, axis, *a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        input_tensor_names = []
        for tensor in a:
            input_tensor_names.append(tensor.name)

        self.ser.addOperator(op, input_tensor_names, [result_tens.name], attr)
        return result_tens

    def build_pad(self, op, a, padding, qinfo):
        result_tens = OutputShaper.padOp(self.ser, a, padding)

        # Need to turn the padding array into a TOSA tensor here.
        # This is one of the few tensor operands that does not get
        # randomly generated
        padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)

        self.ser.addOperator(
            op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_reshape(self, op, a, newShape):
        result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_reverse(self, op, a, axis):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_transpose(self, op, a, perms):
        result_tens = OutputShaper.transposeOp(self.ser, a, perms)

        perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

        self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
        return result_tens

    def build_slice(self, op, a, begin, size):
        result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(begin, size)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_tile(self, op, a, multiples):
        result_tens = OutputShaper.tileOp(self.ser, a, multiples)

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_gather(self, op, values):

        # Create a new indices tensor here with data that doesn't exceed the
        # dimensions of the values tensor
        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indices_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indices)

        self.ser.addOperator(op, [values.name, indices.name], [result_tens.name])

        return result_tens

    def build_scatter(self, op, values_in, input):

        # Create a new indices tensor here with data that doesn't exceed the
        # dimensions of the values_in tensor
        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indices_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indices, input)

        self.ser.addOperator(
            op, [values_in.name, indices.name, input.name], [result_tens.name]
        )

        return result_tens
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens

    def build_identityn(self, op, val, val2):

        result_tens = OutputShaper.unaryOp(self.ser, val)
        result_tens2 = OutputShaper.unaryOp(self.ser, val2)
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens

    def build_placeholder(self, op, val):
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)

    # Type Conversion
    def build_cast(self, op, val, out_dtype):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
        self.ser.addOperator(op, [val.name], [result_tens.name])
        return result_tens
1485
1486 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
1487 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1488
1489 if per_channel:
1490 nc = val.shape[-1]
1491 else:
1492 nc = 1
1493
1494 in_type_width = self.typeWidth(val.dtype)
1495 out_type_width = self.typeWidth(out_dtype)
1496
Kevin Cheng3a478572021-01-22 17:21:02 -08001497 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001498 input_zp = self.randInt(-128, 128)
1499 in_type_width = in_type_width + 1
1500 elif val.dtype == DType.UINT8:
1501 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001502 in_type_width = in_type_width + 1
1503 else:
1504 input_zp = 0
1505
Kevin Cheng3a478572021-01-22 17:21:02 -08001506 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001507 output_zp = self.randInt(-128, 128)
1508 out_type_width = out_type_width + 1
1509 elif out_dtype == DType.UINT8:
1510 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001511 out_type_width = out_type_width + 1
1512 else:
1513 output_zp = 0
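        # Note: the 8-bit dtypes carry a zero-point, so (value - zp) can need
        # one extra bit of range; that is presumably why the type widths are
        # incremented above.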
1514
        # Calculate scale based on:
        #   scale = a * (2^output_width) / (2^input_width)
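        # Illustrative only (the actual decomposition is done by
        # TosaQuantGen.computeMultiplierAndShift below), assuming the usual
        # fixed-point convention scale = multiplier * 2^-shift:
        #   scale32: 0.75 -> multiplier = 1610612736 (0.75 * 2^31), shift = 31
        #   scale16: 0.75 -> multiplier = 24576      (0.75 * 2^15), shift = 15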
1517
1518 a = np.float32(self.rng.random(size=[nc]))
1519 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1520
        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1528
Kevin Cheng550ccc52021-03-03 11:21:43 -08001529 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001530
1531 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1532 shift_arr = np.int32(np.zeros(shape=[nc]))
1533
1534 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001535 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1536 scale_arr[i], scale32
1537 )
Kevin Chengaee1fac2020-11-11 13:54:06 -08001538 if shift_arr[i] < 2 or shift_arr[i] > 62:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001539 self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
Eric Kunzee5e26762020-10-13 16:11:07 -07001540
Kevin Cheng550ccc52021-03-03 11:21:43 -08001541 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07001542
1543 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001544 attr.RescaleAttribute(
1545 input_zp,
1546 output_zp,
1547 multiplier_arr,
1548 shift_arr,
1549 scale32,
1550 double_round,
1551 per_channel,
1552 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001553
1554 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
1555 return result_tens
1556
1557 def build_cond_if_const(self, op, then_tens, else_tens, cond):
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build then/else blocks
        # and fill them with const nodes for the body.
1561
1562 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001563 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001564
1565 # Make then/else tensors
1566 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01001567 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
1568 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001569
1570 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08001571 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07001572
1573 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001574 then_block = "THEN_BLOCK"
1575 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001576 attr = ts.TosaSerializerAttribute()
1577 attr.CondIfAttribute(then_block, else_block)
1578
1579 # Finally, build the op and the two blocks
1580 self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)
1581
1582 self.ser.startBasicBlock(then_block)
1583 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001584 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001585 self.ser.addOutputTensor(then_tens)
1586
1587 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001588 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001589 self.ser.addOutputTensor(else_tens)
1590
1591 return result_tens
1592
1593 def build_cond_if_binary(self, op, a, b, cond):
1594 # For cond_if with a binary op in the then/else blocks, take a and b and
1595 # alternately add or subtract them based on the condition
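        # In pseudocode, the serialized graph computes:
        #   result = (a + b) if cond else (a - b)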
1596
1597 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001598 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001599
Kevin Cheng550ccc52021-03-03 11:21:43 -08001600 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001601 self.ser.currBasicBlock.addOutput(result_tens.name)
1602
1603 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001604 then_block = "THEN_BLOCK"
1605 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001606 attr = ts.TosaSerializerAttribute()
1607 attr.CondIfAttribute(then_block, else_block)
1608
1609 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001610 self.ser.addOperator(
1611 op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
1612 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001613
1614 self.ser.startBasicBlock(then_block)
1615 self.ser.addInputTensor(a)
1616 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001617 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001618 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
1619
1620 self.ser.startBasicBlock(else_block)
1621 self.ser.addInputTensor(a)
1622 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001623 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001624 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
1625
1626 return result_tens
1627
1628 def build_while_loop(self, op, a, iter_val):
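        # Serializes a counted loop; in pseudocode the graph computes:
        #   acc = zeros_like(a)
        #   while iter > 0:
        #       acc = acc + a
        #       iter = iter - 1
        # so, for a non-negative iter_val, the final accumulator is a * iter_val.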
Kevin Cheng550ccc52021-03-03 11:21:43 -08001629 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001630
Kevin Cheng550ccc52021-03-03 11:21:43 -08001631 cond_block = "COND_BLOCK"
1632 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001633
1634 attr = ts.TosaSerializerAttribute()
1635 attr.WhileLoopAttribute(cond_block, body_block)
1636
1637 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001638 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001639 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001640 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001641
1642 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001643 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1644 a_out = self.ser.addIntermediate(a.shape, a.dtype)
1645 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001646
1647 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001648 self.ser.addOperator(
1649 op,
1650 [iter.name, a.name, acc.name],
1651 [iter_out.name, a_out.name, acc_out.name],
1652 attr,
1653 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001654
        # COND block (inputs: iter, a, acc; output: cond_tens)
1656 self.ser.startBasicBlock(cond_block)
1657 self.ser.addInputTensor(iter)
1658 self.ser.addInputTensor(a)
1659 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001660 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
1661 cond_tens = self.ser.addOutput([], DType.BOOL)
1662 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001663
1664 # BODY block (input: a, acc, iter, output: a, acc, iter)
1665 # Note that local intermediate tensors need to be declared here for the outputs
1666 self.ser.startBasicBlock(body_block)
1667 self.ser.addInputTensor(iter)
1668 self.ser.addInputTensor(a)
1669 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001670 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
1671 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1672 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001673 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
1674 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
1675 self.ser.addOutputTensor(iter_body_out)
1676 self.ser.addOutputTensor(a)
1677 self.ser.addOutputTensor(acc_body_out)
1678
1679 return acc_out
1680
Kevin Cheng550ccc52021-03-03 11:21:43 -08001681 def genOpTestList(
1682 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
1683 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001684
1685 try:
1686 op = self.TOSA_OP_LIST[opName]
1687 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001688 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001689
1690 # Initialize a new random number generator
1691 self.rng = np.random.default_rng(self.random_seed)
1692
Kevin Cheng550ccc52021-03-03 11:21:43 -08001693 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001694
1695 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001696 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001697
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001698 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1699 default_test_rank_range = range(1, 5)
1700
        # The test list consists of tuples of:
        #   (opName, testNameStr, dtype, shapeList, argumentsList)
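        # e.g. an entry might look like (shape/type strings are illustrative):
        #   ("add", "add_3x4_i32", DType.INT32, [[3, 4], [3, 4]], [])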
1703 testList = []
1704
1705 if not shapeFilter:
1706 shapeFilter = [None]
1707
1708 for r in range(rmin, rmax + 1):
1709
1710 # Filter out the rank?
1711 if rankFilter is not None and r not in rankFilter:
1712 continue
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001713 if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
1714 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001715
Kevin Cheng550ccc52021-03-03 11:21:43 -08001716 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001717
1718 # Filter tests based on dtype?
1719 if dtypeFilter is not None:
Les Bell30e46802021-07-23 09:43:31 +01001720 if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
Eric Kunzee5e26762020-10-13 16:11:07 -07001721 continue
1722
1723 # Create the placeholder and const tensors
1724 for shape in shapeFilter:
1725 # A None shape chooses a random shape of a given rank
1726
1727 # Filter out by rank
1728 if shape is not None and len(shape) != r:
1729 continue
1730
1731 self.setTargetShape(shape)
1732 shapeList = tgen_fcn(self, op, r)
1733
1734 shapeStr = self.shapeStr(shapeList[0])
1735 typeStr = self.typeStr(t)
1736
                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list
1738 argList = []
1739 if agen_fcn:
1740 argList = agen_fcn(self, opName, shapeList, t)
1741 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001742 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001743
1744 for argStr, args in argList:
1745 if argStr:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001746 testStr = "{}_{}_{}_{}".format(
1747 opName, shapeStr, typeStr, argStr
1748 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001749 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001750 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001751
1752 testList.append((opName, testStr, t, shapeList, args))
1753
1754 return testList
1755
Kevin Cheng989cb052021-04-28 16:29:44 -07001756 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07001757 try:
1758 op = self.TOSA_OP_LIST[opName]
1759 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001760 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001761
1762 # Create a serializer
1763 self.createSerializer(opName, testStr)
1764
Kevin Cheng550ccc52021-03-03 11:21:43 -08001765 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
1766 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07001767 num_operands = pCount + cCount
1768
1769 if isinstance(dtype_or_dtypeList, list):
1770 dtypeList = dtype_or_dtypeList
Matthew Haddon818ab902021-07-27 09:12:49 +01001771 elif op['op'] == Op.CONCAT:
1772 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07001773 else:
1774 dtypeList = [dtype_or_dtypeList] * (num_operands)
1775
Matthew Haddon818ab902021-07-27 09:12:49 +01001776 if op['op'] != Op.CONCAT:
1777 assert (
1778 len(shapeList) == num_operands
1779 ), "shapeList length {} must match number of operands {}".format(
1780 len(shapeList), num_operands
1781 )
1782 assert (
1783 len(dtypeList) == num_operands
1784 ), "dtypeList length {} must match number of operands {}".format(
1785 len(dtypeList), num_operands
1786 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001787
1788 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001789 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001790 except KeyError:
1791 qgen = None
1792
1793 # Build the random tensor operands and the test
1794 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08001795
1796 # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001797 if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
1798 assert (
1799 pCount == 2 and cCount == 0
1800 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08001801
1802 placeholders = []
1803 for idx, shape in enumerate(shapeList[:]):
1804 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07001805 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001806 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001807 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001808 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001809 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001810 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
1811 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001812 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08001813 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001814 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07001815 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001816
1817 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01001818 elif op["op"] == Op.SELECT:
1819 # Set datatype of condition tensor to boolean
1820 dtypeList[0] = DType.BOOL
1821 tens.extend(
1822 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1823 )
1824 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001825 elif op["op"] == Op.DIV:
1826 assert (
1827 pCount == 2 and cCount == 0
1828 ), "Op.Div must have 2 placeholders, 0 consts"
1829
1830 placeholders = []
1831
1832 # Two invalid cases for Op.DIV:
1833 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07001834 # 2. dividend == -(1<<31) and divisor == -1
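            #    (case 2 would overflow: -(1 << 31) // -1 == 1 << 31, which is
            #    one past INT32_MAX)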
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001835 while True:
1836 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
1837 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
1838
1839 if (divisor_arr == 0).any():
1840 continue
1841
Kevin Cheng47315e12021-05-13 17:41:28 -07001842 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001843 continue
1844
1845 break
1846
1847 placeholders.append(
1848 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1849 )
1850 placeholders.append(
1851 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1852 )
1853
1854 tens.extend(placeholders)
1855 elif op["op"] == Op.MUL:
1856 assert (
1857 pCount == 2 and cCount == 0
1858 ), "Op.MUL must have 2 placeholders, 0 consts"
1859
1860 if dtypeList[0] == DType.FLOAT:
1861 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
1862 else:
1863 placeholders = []
1864
                # Make sure the multiply result fits in the int32 range
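                # e.g. with shift == 2 the reference result is
                # (a * b + 2) >> 2, computed at 64 bits; the halving loop below
                # shrinks a and b until that value fits in (-2^31, 2^31 - 1]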
1866 shift = testArgs[0]
1867 if dtypeList[0] == DType.INT8:
1868 num_bits = 8
1869 elif dtypeList[0] == DType.INT16:
1870 num_bits = 16
1871 elif dtypeList[0] == DType.INT32:
1872 num_bits = 32
1873 else:
1874 raise Exception("OpMul: invalid input dtype")
1875
1876 for idx, shape in enumerate(shapeList[:]):
1877 low = -(2 ** (num_bits - 1))
1878 high = (2 ** (num_bits - 1)) - 1
1879
1880 a_arr = np.int32(
1881 self.rng.integers(low=low, high=high, size=shapeList[0])
1882 )
1883 b_arr = np.int32(
1884 self.rng.integers(low=low, high=high, size=shapeList[1])
1885 )
1886
1887 i = 0
1888 while True:
1889
1890 a_arr_64 = a_arr.astype(np.int64)
1891 b_arr_64 = b_arr.astype(np.int64)
1892
1893 if shift > 0:
1894 rounding = 1 << (shift - 1)
1895 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
1896 else:
1897 result_arr = a_arr_64 * b_arr_64
1898
1899 if (result_arr > -(2 ** 31)).all() and (
1900 result_arr <= ((2 ** 31) - 1)
1901 ).all():
1902 break
1903
1904 i = i + 1
1905 a_arr = a_arr // 2
1906 b_arr = b_arr // 2
1907
1908 placeholders.append(
1909 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1910 )
1911 placeholders.append(
1912 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1913 )
1914
1915 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01001916 elif op["op"] == Op.CONCAT:
1917 count = len(shapeList) - self.args.num_const_inputs_concat
1918 if count < 1:
1919 count = 1
1920 if self.args.num_const_inputs_concat == 0:
1921 count = len(shapeList)
1922
1923 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
1924 tens.extend(
1925 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1926 )
1927 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001928 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001929 tens.extend(
1930 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1931 )
1932 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001933
1934 if qgen is not None:
Les Bell30e46802021-07-23 09:43:31 +01001935 qinfo = qgen(self, op, dtype_or_dtypeList)
Eric Kunzee5e26762020-10-13 16:11:07 -07001936 else:
1937 qinfo = None
1938
1939 try:
1940 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001941 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001942 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001943 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07001944 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001945 print(
1946 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
1947 build_fcn, tens, testArgs
1948 )
1949 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001950 raise e
1951
1952 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08001953 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001954
1955 def createDynamicOpLists(self):
1956
1957 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001958 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001959
1960 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001961 testName = "conv2d_{}x{}".format(k[0], k[1])
1962 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1963 self.TOSA_OP_LIST[testName]["filter"] = k
1964 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001965
Kevin Cheng550ccc52021-03-03 11:21:43 -08001966 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1967 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1968 "depthwise_conv2d_TEMPLATE"
1969 ].copy()
1970 self.TOSA_OP_LIST[testName]["filter"] = k
1971 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001972
Kevin Cheng550ccc52021-03-03 11:21:43 -08001973 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1974 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1975 "transpose_conv2d_TEMPLATE"
1976 ].copy()
1977 self.TOSA_OP_LIST[testName]["filter"] = k
1978 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001979
1980 # Delete any templates after having created any dynamic ops
1981 # This is a two-pass operation because it's bad practice to delete
1982 # keys from dictionaries while iterating
1983 keyList = []
1984 for k in self.TOSA_OP_LIST:
1985 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001986 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001987 keyList.append(k)
1988 continue
1989 except KeyError:
1990 pass
1991
1992 for k in keyList:
1993 del self.TOSA_OP_LIST[k]
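        # For example, with [3, 3] in KERNELS, the "conv2d_TEMPLATE" entry
        # above yields a concrete "conv2d_3x3" entry with "filter" == [3, 3],
        # and the template itself is then deleted.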
1994
1995 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001996 """Fill in default fields for ops if they aren't already specified.
1997 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001998 for op in self.TOSA_OP_LIST:
1999
2000 # Required fields
2001 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002002 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002003 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002004 raise Exception(
2005 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2006 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002007
2008 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002009 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002010 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002011 raise Exception(
2012 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2013 op
2014 )
2015 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002016
2017 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002018 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002019 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002020 raise Exception(
2021 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2022 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002023
2024 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002025 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002026 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002027 raise Exception(
2028 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2029 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002030
2031 # Put in default rank range, if missing
2032 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002033 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002034 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002035 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002036
2037 # Tensor operator list
2038 # 'op': op name
2039 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002040 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2041 # if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of (build function, TensorGen function, ArgGen function)
2043 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08002044 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07002045
Kevin Cheng550ccc52021-03-03 11:21:43 -08002046 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
2047 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002048
Kevin Cheng550ccc52021-03-03 11:21:43 -08002049 TYPE_BOOL = [DType.BOOL]
2050 TYPE_FI32 = [DType.FLOAT, DType.INT32]
2051 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
2052 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002053
Kevin Cheng550ccc52021-03-03 11:21:43 -08002054 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07002055
Kevin Cheng989cb052021-04-28 16:29:44 -07002056 TYPE_CONV2D = [
Kevin Chenga9017402021-07-28 17:19:23 -07002057 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002058 [DType.INT8, DType.INT8, DType.INT32],
2059 [DType.INT16, DType.INT8, DType.INT48],
2060 DType.FLOAT,
2061 ]
2062
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002063 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002064
2065 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002066 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002067 "argmax": {
2068 "op": Op.ARGMAX,
2069 "operands": (1, 0),
2070 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2071 "types": TYPE_NARROW_INT_FP,
2072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002073 "avg_pool2d": {
2074 "op": Op.AVG_POOL2D,
2075 "operands": (1, 0),
2076 "rank": (4, 4),
2077 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2078 "qgen": TosaQuantGen.qgUnary,
2079 "types": TYPE_NARROW_INT_FP,
2080 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002081 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002082 "conv2d_TEMPLATE": {
2083 "op": Op.CONV2D,
2084 "operands": (1, 2),
2085 "rank": (4, 4),
2086 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
2087 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002088 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002089 "template": True,
2090 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002091 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07002092 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002093 "depthwise_conv2d_TEMPLATE": {
2094 "op": Op.DEPTHWISE_CONV2D,
2095 "operands": (1, 2),
2096 "filter": [1, 1],
2097 "rank": (4, 4),
2098 "build_fcn": (
2099 build_depthwise_conv2d,
2100 TosaTensorGen.tgDepthwiseConv2D,
2101 TosaArgGen.agConv2D,
2102 ),
2103 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002104 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002105 "template": True,
2106 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002107 "fully_connected": {
2108 "op": Op.FULLY_CONNECTED,
2109 "operands": (1, 2),
2110 "rank": (2, 2),
2111 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2112 "qgen": TosaQuantGen.qgConv,
2113 "types": TYPE_CONV2D,
2114 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002115 "matmul": {
2116 "op": Op.MATMUL,
2117 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002118 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002119 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2120 "qgen": TosaQuantGen.qgMatmul,
2121 "types": TYPE_NARROW_INT_FP,
2122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002123 "max_pool2d": {
2124 "op": Op.MAX_POOL2D,
2125 "operands": (1, 0),
2126 "rank": (4, 4),
2127 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2128 "types": TYPE_NARROW_INT_FP,
2129 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002130 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002131 "transpose_conv2d_TEMPLATE": {
2132 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002133 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002134 "rank": (4, 4),
2135 "build_fcn": (
2136 build_transpose_conv2d,
2137 TosaTensorGen.tgTransposeConv2D,
2138 TosaArgGen.agTransposeConv2D,
2139 ),
2140 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002141 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002142 "template": True,
2143 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002144 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002145 "clamp": {
2146 "op": Op.CLAMP,
2147 "operands": (1, 0),
2148 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2149 "types": TYPE_NARROW_INT_FP,
2150 },
2151 "relun": {
2152 "op": Op.RELUN,
2153 "operands": (1, 0),
2154 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2155 "types": TYPE_FI32,
2156 },
2157 "sigmoid": {
2158 "op": Op.SIGMOID,
2159 "operands": (1, 0),
2160 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2161 "types": TYPE_FP,
2162 },
2163 "tanh": {
2164 "op": Op.TANH,
2165 "operands": (1, 0),
2166 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2167 "types": TYPE_FP,
2168 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002169 # Elementwise Binary Operators
2170 "add": {
2171 "op": Op.ADD,
2172 "operands": (2, 0),
2173 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2174 "types": TYPE_FI32,
2175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002176 "arithmetic_right_shift": {
2177 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2178 "operands": (2, 0),
2179 "build_fcn": (
2180 build_arithmetic_right_shift,
2181 TosaTensorGen.tgBroadcastFuzz,
2182 TosaArgGen.agArithmeticRightShift,
2183 ),
2184 "types": TYPE_INT,
2185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002186 "bitwise_and": {
2187 "op": Op.BITWISE_AND,
2188 "operands": (2, 0),
2189 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2190 "types": TYPE_INT,
2191 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002192 "bitwise_or": {
2193 "op": Op.BITWISE_OR,
2194 "operands": (2, 0),
2195 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2196 "types": TYPE_INT,
2197 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002198 "bitwise_xor": {
2199 "op": Op.BITWISE_XOR,
2200 "operands": (2, 0),
2201 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2202 "types": TYPE_INT,
2203 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002204 "div": {
2205 "op": Op.DIV,
2206 "operands": (2, 0),
2207 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2208 "types": [DType.INT32],
2209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002210 "logical_and": {
2211 "op": Op.LOGICAL_AND,
2212 "operands": (2, 0),
2213 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2214 "types": TYPE_BOOL,
2215 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002216 "logical_left_shift": {
2217 "op": Op.LOGICAL_LEFT_SHIFT,
2218 "operands": (2, 0),
2219 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2220 "types": TYPE_INT,
2221 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002222 "logical_right_shift": {
2223 "op": Op.LOGICAL_RIGHT_SHIFT,
2224 "operands": (2, 0),
2225 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2226 "types": TYPE_INT,
2227 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002228 "logical_or": {
2229 "op": Op.LOGICAL_OR,
2230 "operands": (2, 0),
2231 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2232 "types": TYPE_BOOL,
2233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002234 "logical_xor": {
2235 "op": Op.LOGICAL_XOR,
2236 "operands": (2, 0),
2237 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2238 "types": TYPE_BOOL,
2239 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002240 "maximum": {
2241 "op": Op.MAXIMUM,
2242 "operands": (2, 0),
2243 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2244 "types": TYPE_FI32,
2245 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002246 "minimum": {
2247 "op": Op.MINIMUM,
2248 "operands": (2, 0),
2249 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2250 "types": TYPE_FI32,
2251 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002252 "mul": {
2253 "op": Op.MUL,
2254 "operands": (2, 0),
2255 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2256 "types": TYPE_INT_FP,
2257 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002258 "pow": {
2259 "op": Op.POW,
2260 "operands": (2, 0),
2261 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2262 "types": TYPE_FP,
2263 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002264 "sub": {
2265 "op": Op.SUB,
2266 "operands": (2, 0),
2267 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2268 "types": TYPE_FI32,
2269 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002270 "table": {
2271 "op": Op.TABLE,
2272 # Use the automatic generation functions to create the input array
2273 # but create the table tensor in the build function, as it may be
2274 # a different type from the input
2275 "operands": (1, 0),
2276 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002277 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002279 # Elementwise Unary operators
2280 "abs": {
2281 "op": Op.ABS,
2282 "operands": (1, 0),
2283 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2284 "types": TYPE_FI32,
2285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002286 "bitwise_not": {
2287 "op": Op.BITWISE_NOT,
2288 "operands": (1, 0),
2289 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2290 "types": TYPE_INT,
2291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002292 "ceil": {
2293 "op": Op.CEIL,
2294 "operands": (1, 0),
2295 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2296 "types": TYPE_FP,
2297 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002298 "clz": {
2299 "op": Op.CLZ,
2300 "operands": (1, 0),
2301 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2302 "types": [DType.INT32],
2303 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002304 "exp": {
2305 "op": Op.EXP,
2306 "operands": (1, 0),
2307 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2308 "types": TYPE_FP,
2309 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002310 "floor": {
2311 "op": Op.FLOOR,
2312 "operands": (1, 0),
2313 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2314 "types": TYPE_FP,
2315 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002316 "log": {
2317 "op": Op.LOG,
2318 "operands": (1, 0),
2319 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2320 "types": TYPE_FP,
2321 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002322 "logical_not": {
2323 "op": Op.LOGICAL_NOT,
2324 "operands": (1, 0),
2325 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2326 "types": TYPE_BOOL,
2327 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002328 "negate": {
2329 "op": Op.NEGATE,
2330 "operands": (1, 0),
2331 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2332 "qgen": TosaQuantGen.qgUnary,
2333 "types": TYPE_INT_FP,
2334 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002335 "reciprocal": {
2336 "op": Op.RECIPROCAL,
2337 "operands": (1, 0),
2338 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2339 "types": TYPE_FP,
2340 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002341 "rsqrt": {
2342 "op": Op.RSQRT,
2343 "operands": (1, 0),
2344 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2345 "types": TYPE_FP,
2346 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002347 # Elementwise Ternary operators
2348 "select": {
2349 "op": Op.SELECT,
2350 "operands": (3, 0),
2351 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2352 "types": TYPE_FIB,
2353 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002354 # Comparison operators
2355 "equal": {
2356 "op": Op.EQUAL,
2357 "operands": (2, 0),
2358 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2359 "types": TYPE_FI32,
2360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002361 "greater_equal": {
2362 "op": Op.GREATER_EQUAL,
2363 "operands": (2, 0),
2364 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2365 "types": TYPE_FI32,
2366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002367 "greater": {
2368 "op": Op.GREATER,
2369 "operands": (2, 0),
2370 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2371 "types": TYPE_FI32,
2372 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002373 # Reduction operators
2374 "reduce_all": {
2375 "op": Op.REDUCE_ALL,
2376 "operands": (1, 0),
2377 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2378 "types": TYPE_BOOL,
2379 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002380 "reduce_any": {
2381 "op": Op.REDUCE_ANY,
2382 "operands": (1, 0),
2383 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2384 "types": TYPE_BOOL,
2385 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002386 "reduce_max": {
2387 "op": Op.REDUCE_MAX,
2388 "operands": (1, 0),
2389 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2390 "types": TYPE_INT_FP,
2391 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002392 "reduce_min": {
2393 "op": Op.REDUCE_MAX,
2394 "operands": (1, 0),
2395 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2396 "types": TYPE_INT_FP,
2397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002398 "reduce_product": {
2399 "op": Op.REDUCE_PRODUCT,
2400 "operands": (1, 0),
2401 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2402 "types": TYPE_FP,
2403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002404 "reduce_sum": {
2405 "op": Op.REDUCE_SUM,
2406 "operands": (1, 0),
2407 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2408 "types": TYPE_FI32,
2409 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002410 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002411 "concat": {
2412 "op": Op.CONCAT,
2413 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01002414 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002415 "types": TYPE_FIB,
2416 },
2417 "pad": {
2418 "op": Op.PAD,
2419 "operands": (1, 0),
2420 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2421 "qgen": TosaQuantGen.qgPad,
2422 "types": TYPE_FIB,
2423 },
2424 "reshape": {
2425 "op": Op.RESHAPE,
2426 "operands": (1, 0),
2427 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2428 "types": TYPE_FIB,
2429 },
2430 "reverse": {
2431 "op": Op.REVERSE,
2432 "operands": (1, 0),
2433 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2434 "types": TYPE_FIB,
2435 },
2436 "slice": {
2437 "op": Op.SLICE,
2438 "operands": (1, 0),
2439 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2440 "types": TYPE_FIB,
2441 },
2442 "tile": {
2443 "op": Op.TILE,
2444 "operands": (1, 0),
2445 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2446 "types": TYPE_FIB,
2447 },
2448 "transpose": {
2449 "op": Op.TRANSPOSE,
2450 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002451 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002452 "build_fcn": (
2453 build_transpose,
2454 TosaTensorGen.tgBasic,
2455 TosaArgGen.agTranspose,
2456 ),
2457 "types": TYPE_FIB,
2458 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002459 # Data nodes
2460 "const": {
2461 "op": Op.CONST,
2462 "operands": (1, 0),
2463 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2464 "types": TYPE_FIB,
2465 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002466 "identity": {
2467 "op": Op.IDENTITY,
2468 "operands": (1, 0),
2469 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2470 "types": TYPE_FIB,
2471 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002472 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002473 "gather": {
2474 "op": Op.GATHER,
            # Only specify the 'values' tensor here. 'indices' is generated in the op building stage.
2476 "operands": (1, 0),
2477 "rank": (3, 3),
2478 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2479 "types": TYPE_INT_FP,
2480 },
2481 "scatter": {
2482 "op": Op.SCATTER,
            # Only specify the 'values_in' tensor here.
            # 'indices' and 'input' are generated in the op building stage.
2485 "operands": (2, 0),
2486 "rank": (3, 3),
2487 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2488 "types": TYPE_INT_FP,
2489 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002490 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002491 "resize": {
2492 "op": Op.RESIZE,
2493 "operands": (1, 0),
2494 "rank": (4, 4),
2495 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2496 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2497 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002498 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002499 "cast": {
2500 "op": Op.CAST,
2501 "operands": (1, 0),
2502 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2503 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2504 },
2505 "rescale": {
2506 "op": Op.RESCALE,
2507 "operands": (1, 0),
2508 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002509 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002510 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002511 # Custom
2512 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002513 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
        # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002517 "cond_if_const": {
2518 "op": Op.COND_IF,
2519 "operands": (0, 2),
2520 "build_fcn": (
2521 build_cond_if_const,
2522 TosaTensorGen.tgBasic,
2523 TosaArgGen.agCondIf,
2524 ),
2525 "types": [DType.BOOL],
2526 },
2527 "cond_if_binary": {
2528 "op": Op.COND_IF,
2529 "operands": (2, 0),
2530 "build_fcn": (
2531 build_cond_if_binary,
2532 TosaTensorGen.tgBasic,
2533 TosaArgGen.agCondIf,
2534 ),
2535 "types": TYPE_FI32,
2536 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002537 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002538 "while_loop": {
2539 "op": Op.WHILE_LOOP,
2540 "operands": (0, 1),
2541 "build_fcn": (
2542 build_while_loop,
2543 TosaTensorGen.tgBasic,
2544 TosaArgGen.agWhileLoop,
2545 ),
2546 "types": [DType.INT32],
2547 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002548 }
2549
Kevin Cheng550ccc52021-03-03 11:21:43 -08002550
Eric Kunzee5e26762020-10-13 16:11:07 -07002551class OutputShaper:
2552 # Methods in this class compute the expected output shape and datatype
2553 # for common classes of operations
2554 def __init__(self):
2555 pass
2556
2557 # These methods return arguments that can be used for
2558 # creating a new output tensor
2559 @staticmethod
2560 def binaryBroadcastOp(ser, a, b):
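        # Shape rule sketch: size-1 dimensions of 'a' take their size from 'b',
        # e.g. a.shape == [1, 4] with b.shape == [3, 4] produces [3, 4]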
Kevin Cheng550ccc52021-03-03 11:21:43 -08002561 assert len(a.shape) == len(b.shape)
2562 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002563
2564 shape = []
2565 for i in range(len(a.shape)):
2566 if a.shape[i] == 1:
2567 shape.append(b.shape[i])
2568 else:
2569 shape.append(a.shape[i])
2570
Kevin Cheng550ccc52021-03-03 11:21:43 -08002571 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002572
2573 @staticmethod
2574 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002575 assert len(a.shape) == len(b.shape)
2576 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002577
2578 shape = []
2579 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002580 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002581 shape.append(a.shape[i])
2582
Kevin Cheng550ccc52021-03-03 11:21:43 -08002583 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002584
2585 @staticmethod
2586 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002587 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002588
2589 @staticmethod
2590 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002591 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2592 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002593
2594 shape = []
2595 for i in range(len(a.shape)):
2596 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2597
Kevin Cheng550ccc52021-03-03 11:21:43 -08002598 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002599
2600 @staticmethod
2601 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002602 assert len(a.shape) == len(b.shape)
2603 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002604
2605 # Do broadcast
2606 shape = []
2607 for i in range(len(a.shape)):
2608 if a.shape[i] == 1:
2609 shape.append(b.shape[i])
2610 else:
2611 shape.append(a.shape[i])
2612
2613 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002614 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002615
2616 @staticmethod
2617 def reduceOp(ser, a, axis):
2618
2619 shape = a.shape.copy()
2620
2621 shape[axis] = 1
2622
Kevin Cheng550ccc52021-03-03 11:21:43 -08002623 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002624
2625 @staticmethod
2626 def argmaxOp(ser, a, axis):
2627 shape = a.shape.copy()
2628 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002629 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002630
2631 @staticmethod
2632 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
2633
2634 # IFM: NHWC
2635 # Filter: OHWI
2636 # OFM: NHWC
2637
2638 if len(padding) == 2:
2639 # Expand padding to 4 parameters in the case of transpose_conv2d
2640 # From H,W to T,B,L,R
2641 padding = [padding[0], padding[0], padding[1], padding[1]]
2642
Kevin Cheng550ccc52021-03-03 11:21:43 -08002643 h = (
2644 ifm.shape[1]
2645 - filter.shape[1]
2646 - (filter.shape[1] - 1) * (dilations[0] - 1)
2647 + padding[0]
2648 + padding[1]
2649 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002650
Kevin Cheng550ccc52021-03-03 11:21:43 -08002651 w = (
2652 ifm.shape[2]
2653 - filter.shape[2]
2654 - (filter.shape[2] - 1) * (dilations[1] - 1)
2655 + padding[2]
2656 + padding[3]
2657 ) // strides[1] + 1
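        # Worked example for one spatial dimension: ifm height 8, kernel height 3,
        # dilation 1, padding 0/0, stride 2 gives h = (8 - 3 - 0 + 0 + 0) // 2 + 1 = 3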
Eric Kunzee5e26762020-10-13 16:11:07 -07002658
2659 if h <= 0 or w <= 0:
2660 # Invalid test parameters?
2661 h = 0
2662 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002663 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002664
2665 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
2666
Kevin Cheng3a478572021-01-22 17:21:02 -08002667 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002668 out_dtype = DType.INT32
2669 elif ifm.dtype == DType.INT16:
2670 out_dtype = DType.INT48
2671 elif ifm.dtype == DType.FLOAT:
2672 out_dtype = DType.FLOAT
2673 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002674 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002675
Kevin Cheng550ccc52021-03-03 11:21:43 -08002676 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002677
2678 @staticmethod
2679 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
2680 # IFM: NHWC
2681 # Filter: HWCM
2682 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 h = (
2684 ifm.shape[1]
2685 - filter.shape[0]
2686 - (filter.shape[0] - 1) * (dilations[0] - 1)
2687 + padding[0]
2688 + padding[1]
2689 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002690
Kevin Cheng550ccc52021-03-03 11:21:43 -08002691 w = (
2692 ifm.shape[2]
2693 - filter.shape[1]
2694 - (filter.shape[1] - 1) * (dilations[1] - 1)
2695 + padding[2]
2696 + padding[3]
2697 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002698
2699 if h <= 0 or w <= 0:
2700 # Invalid test parameters?
2701 h = 0
2702 w = 0
            ser.setExpectedFailure(True, "Invalid combination of depthwise_conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002704
2705 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
2706
Kevin Cheng3a478572021-01-22 17:21:02 -08002707 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002708 out_dtype = DType.INT32
2709 elif ifm.dtype == DType.INT16:
2710 out_dtype = DType.INT48
2711 elif ifm.dtype == DType.FLOAT:
2712 out_dtype = DType.FLOAT
2713 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002714 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002715
Kevin Cheng550ccc52021-03-03 11:21:43 -08002716 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002717
2718 @staticmethod
2719 def pool2dOp(ser, ifm, kernel, stride, pad):
2720 # input: NHWC
2721 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2722 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
2723
2724 if h <= 0 or w <= 0:
2725 # Invalid test parameters?
2726 h = 0
2727 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002728 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002729
2730 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002731 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002732
2733 @staticmethod
2734 def fullyConnectedOp(ser, input, filter):
2735 # input: N, IC
2736 # filter: OC, IC
2737 # output: N, OC
2738
2739 output_shape = [input.shape[0], filter.shape[0]]
2740
Kevin Cheng3a478572021-01-22 17:21:02 -08002741 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002742 out_dtype = DType.INT32
2743 elif input.dtype == DType.INT16:
2744 out_dtype = DType.INT48
2745 elif input.dtype == DType.FLOAT:
2746 out_dtype = DType.FLOAT
2747 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002748 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002749
Kevin Cheng550ccc52021-03-03 11:21:43 -08002750 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002751
2752 @staticmethod
2753 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07002754 # a: N, H, C
2755 # b: N, C, W
2756 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07002757
Kevin Cheng2d60f002021-06-09 14:18:32 -07002758 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002759
Kevin Cheng3a478572021-01-22 17:21:02 -08002760 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002761 out_dtype = DType.INT32
2762 elif a.dtype == DType.INT16:
2763 out_dtype = DType.INT48
2764 elif a.dtype == DType.FLOAT:
2765 out_dtype = DType.FLOAT
2766 else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002768
Kevin Cheng550ccc52021-03-03 11:21:43 -08002769 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002770
2771 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01002772 def concatOp(ser, axis, *a):
2773 input1 = a[0]
2774 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07002775
        output_shape = input1.shape.copy()

        # Accumulate the size along the concatenation axis across all inputs
        for tensor in remaining_inputs:
            output_shape[axis] += tensor.shape[axis]
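        # e.g. concatenating shapes [2, 3] and [2, 5] along axis 1 gives [2, 8]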
2782
2783 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002784
2785 @staticmethod
2786 def padOp(ser, a, padding):
2787
2788 output_shape = a.shape.copy()
2789
2790 for i in range(len(output_shape)):
2791 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
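        # e.g. a.shape == [2, 3] with padding [[1, 1], [0, 2]] gives [4, 5]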
2792
Kevin Cheng550ccc52021-03-03 11:21:43 -08002793 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002794
2795 @staticmethod
2796 def reshapeOp(ser, a, shape):
2797 output_shape = shape.copy()
2798
2799 totalElements = 1
2800 for i in a.shape:
2801 totalElements *= i
2802
2803 # If there are any -1 elements, figure out what that dimension must be
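        # (e.g. reshaping a [2, 3, 4] tensor, 24 elements, to [4, -1] resolves
        # the -1 to 24 // 4 == 6)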
2804 totalOutputElements = 1
2805 for i in output_shape:
2806 if i != -1:
2807 totalOutputElements *= i
2808
2809 # And fill it in
2810 for i in range(len(output_shape)):
2811 if output_shape[i] == -1:
2812 output_shape[i] = totalElements // totalOutputElements
2813
Kevin Cheng550ccc52021-03-03 11:21:43 -08002814 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002815
2816 @staticmethod
2817 def sliceOp(ser, a, begin, size):
2818
2819 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002820 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002821
2822 @staticmethod
2823 def tileOp(ser, a, multiples):
2824
2825 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002826 assert len(multiples) == len(output_shape)
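        # e.g. a.shape == [2, 3] with multiples == [2, 1] gives [4, 3]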
Eric Kunzee5e26762020-10-13 16:11:07 -07002827
2828 for i in range(len(output_shape)):
2829 output_shape[i] = a.shape[i] * multiples[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def transposeOp(ser, a, perms):
        output_shape = a.shape.copy()
        assert len(perms) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[perms[i]]
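        # e.g. (illustrative) a shape of [1, 2, 3] with perms [2, 0, 1]
        # produces [3, 1, 2]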

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def gatherOp(ser, values, indices):
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]
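        # values: [N, K, C], indices: [N, W] -> output: [N, W, C]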

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        return ser.addOutput(output_shape, values.dtype)

    @staticmethod
    def scatterOp(ser, values_in, indices, input):
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C
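        # values_in: [N, K, C], indices: [N, W], input: [N, W, C]
        # -> output: [N, K, C], the same shape and dtype as values_in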

        output_shape = values_in.shape

        return ser.addOutput(output_shape, values_in.dtype)

    @staticmethod
    def tableOp(ser, input, table_dtype):
        # Same shape as the input, but dtype dependent on table dtype
        assert table_dtype == DType.INT16 or table_dtype == DType.INT8
        output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        ser,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):

        output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
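        # Output is NHWC: the batch and channel dimensions are preserved,
        # while the spatial dimensions are replaced by the requested
        # output_dims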

        if input_dtype == DType.FLOAT:
            if stride_fp[0] <= 0 or stride_fp[1] <= 0:
                ser.setExpectedFailure(True, "Negative or zero stride")
        else:
            if stride[0] <= 0 or stride[1] <= 0:
                ser.setExpectedFailure(True, "Negative or zero stride")

        if mode == ResizeMode.BILINEAR:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT32:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT48:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        elif mode == ResizeMode.NEAREST:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT8:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT16:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        else:
            ser.setExpectedFailure(True, "Invalid resize mode")

        return ser.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
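        # A cast preserves the input shape; only the element dtype changes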
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
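        # The output dtype is the wider accumulator type for the input:
        # INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT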
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        if output_shape[1] <= 0 or output_shape[2] <= 0:
            ser.setExpectedFailure(True, "Negative output shape")

        return ser.addOutput(output_shape, out_dtype)