#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique
from tosa_ref_run import TosaReturnCode

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa
from tosa_error_if import ErrorIf

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()

class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
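            # e.g. passing DType.INT8 here behaves the same as passing
            # [DType.INT8, DType.INT8, DType.INT8] (illustrative note)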
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
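        # Worked example (illustrative numbers only): scaleFp = 0.5 with
        # scale32 = True gives math.frexp(0.5) == (0.5, 0), so
        #   multiplier = round(0.5 * (1 << 31)) = 1 << 30
        #   shift      = -0 + 31 = 31
        # and the represented scale is multiplier * 2**-shift = 0.5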

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
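        # e.g. with shape [3, 4, 5], the chosen input may be fuzzed to
        # [3, 1, 5] (illustrative values; the axis is picked at random below)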
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random dimension to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
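        # Worked example (illustrative values): four input shapes of [8, 8]
        # with axis = 0 become [[8, 8], [4, 8], [2, 8], [2, 8]] -- the first
        # shape is kept whole and the remaining inputs repeatedly halve the
        # leftover length on the concat axis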
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
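        # Illustrative numbers: 1300 combinations give sparsity
        # 1300 // 100 + 1 = 14, which steps up to 17 (the first value coprime
        # to 2, 3 and 5), keeping roughly one in every 17 combinations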
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))
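        # e.g. for rank 2 with pad_max = 1 there are 2 * 2 = 4 (before, after)
        # pairs per axis and 4**2 = 16 padding combinations in total
        # (illustrative count)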
Eric Kunzee5e26762020-10-13 16:11:07 -0700608
Les Bell7ffccce2021-07-28 15:37:02 +0100609 for paddings in shape_pad_values:
610 name = "pad"
611 for r in range(rank):
612 before, after = paddings[r]
613 name = f"{name}{before}{after}"
614 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700615
616 return arg_list
617
618 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100619 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700620 arg_list = []
621
622 shape = shapeList[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800623 assert len(shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700624
Les Bell7aa69f42021-09-20 10:44:07 +0100625 # Generate comprehensive argument lists
626 p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
627 paddings = {x for x in itertools.product(*([p_vals] * 4))}
628 s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
629 strides = {x for x in itertools.product(*([s_vals] * 2))}
630 k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
631 kernels = {x for x in itertools.product(*([k_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700632
Les Bell7aa69f42021-09-20 10:44:07 +0100633 # add some oversize argument values
634 bigStride = 7
635 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
636 bigKernel = 6
637 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
638 if max(shape) < 64:
639 # padding must be less than the kernel size
640 bigPadding = bigKernel - 1
641 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700642
Les Bell7aa69f42021-09-20 10:44:07 +0100643 # There are too many parameter combinations, so generate them sparsely
644 sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
645 n = 0
646 for s in sorted(list(strides)):
647 for p in sorted(list(paddings)):
648 for k in sorted(list(kernels)):
649 if (n % sparsity == 0
650 # padding must not exceed the kernel size
651 and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
652 # the padded shape must exceed the kernel size
653 and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
654 ):
655 arg_list.append(
656 (
657 "st{}_kern{}_pad{}".format(
658 "".join([str(x) for x in s]),
659 "".join([str(x) for x in k]),
660 "".join([str(x) for x in p]),
661 ),
662 [s, p, k],
663 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800664 )
Les Bell7aa69f42021-09-20 10:44:07 +0100665 n += 1
666
Eric Kunzee5e26762020-10-13 16:11:07 -0700667 return arg_list
668
669 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100670 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700671 arg_list = []
672
673 # Enumerate the output types here
674 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800675 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700676 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800677 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700678 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800679 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700680 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800681 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700682 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800683 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800685 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700686
687 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800688 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700689
690 return arg_list
691
692 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100693 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700694 arg_list = []
695
696 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100697 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
698 if inDtype == DType.UINT8 and dtype != DType.INT8:
699 # The only output dtype for UINT8 is INT8, skip all other combinations
700 continue
701 if inDtype != DType.INT8 and dtype == DType.UINT8:
702 # The only input dtype for UINT8 is INT8, skip all other combinations
703 continue
704
Kevin Cheng550ccc52021-03-03 11:21:43 -0800705 for scale32 in [False, True]:
706 for double_round in [False, True]:
707 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700708
709 if inDtype == DType.INT48 and scale32:
710 # Illegal condition. Must be scale32=False
711 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100712 if double_round and not scale32:
713 # Illegal condition. ERROR_IF(!scale32 && double_round)
714 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700715
Kevin Cheng550ccc52021-03-03 11:21:43 -0800716 arg_list.append(
717 (
718 "out{}_sc{}_dr{}_pc{}".format(
719 DTypeNames[dtype],
720 int(scale32),
721 int(double_round),
722 int(per_channel),
723 ),
724 [dtype, scale32, double_round, per_channel],
725 )
726 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700727
728 return arg_list
729
Kevin Chengaee1fac2020-11-11 13:54:06 -0800730 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100731 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800732 arg_list = []
733
734 if dtype is DType.INT32:
735 for p in range(testGen.args.num_rand_permutations):
736
737 shift = testGen.randInt(0, 32)
738
Kevin Cheng550ccc52021-03-03 11:21:43 -0800739 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800740 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100741 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800742
743 return arg_list
744
745 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100746 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800747 arg_list = []
748
Kevin Cheng550ccc52021-03-03 11:21:43 -0800749 arg_list.append(("roundTrue", [True]))
750 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800751
752 return arg_list
753
Eric Kunzee5e26762020-10-13 16:11:07 -0700754 # Helper function for reshape. Gets some factors of a larger number.
755 @staticmethod
756 def getFactors(val, start=1):
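        """Return the factors of val up to sqrt(val),
        e.g. getFactors(16) -> [1, 2, 4] (illustrative example)."""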
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it runs for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
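        # e.g. a rank-3 input yields 3! = 6 permutations; the limit below caps
        # how many of them become tests (illustrative example)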

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random but small multiple values,
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them.
                    # An output_dim of 1 would cause the offset to exceed the
                    # allowed range, so the minimum value of 2 is produced below.
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1
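                    # e.g. with ifm_shape[1] = 40 the first loop stops at
                    # output_dims[0] = 3, since 40 / 3 < 16 keeps the downscale
                    # ratio (and hence the stride) below 16 (illustrative values)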

                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))
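                        # Worked example (illustrative values): fp_stride_y = 4.0
                        # with shift = 11 gives unit = 2048 and
                        # stride_y = round(4.0 * 2048) = 8192, i.e. a fixed-point
                        # value with `shift` fractional bits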

                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list

class TosaErrorIfArgGen:

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):

        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

class TosaErrorValidator:

    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        # Check ERROR_IF statements

        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)

            validator_name = val_result['error_name']
            error_result = val_result['error_result']
            error_reason = val_result['error_reason']

            if error_result:
                if error_name == validator_name:
                    serializer.setExpectedReturnCode(2, error_reason)
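                    # the literal return code 2 is assumed to correspond to
                    # TosaReturnCode.ERROR from tosa_ref_run (imported above)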
1181 else:
1182 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1183 return None # Return None to delete test if wrong ERROR_IF is hit
1184 else:
1185 if error_name == validator_name:
1186 print(f"No ERROR_IF hit for {error_name}")
1187 return None
1188
1189 @staticmethod
1190 def evWrongInputType(check=False, **kwargs):
1191 all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1192
1193 # Find the unsupported input data types
1194 assert 'op' in kwargs
1195 op = kwargs['op']
1196 input_dtypes = op['types']
1197 wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))
1198
1199 error_name = ErrorIf.WrongInputType
1200 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1201 error_result = False
1202 error_reason = "Input data type not supported for this operator"
1203
1204 if check:
1205 input_dtype = kwargs['input_dtype']
1206 if input_dtype not in input_dtypes:
1207 error_result = True
1208
1209 info_dict = {
1210 "error_name": error_name,
1211 "error_result": error_result,
1212 "error_reason": error_reason,
1213 "param_reqs": param_reqs
1214 }
1215 return info_dict
1216
1217 @staticmethod
1218 def evWrongOutputType(check=False, **kwargs):
1219 error_name = ErrorIf.WrongOutputType
1220 param_reqs = {"rank": None, "dtype": None, "shape": None}
1221 error_result = False
1222 error_reason = "Output data type not supported for this configuration of operator"
1223
1224 if check:
1225 input_dtype = kwargs['input_dtype']
1226 output_dtype = kwargs['output_dtype']
1227 mode = kwargs['mode']
1228
1229 if (
1230 (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
1231 (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
1232 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
1233 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
1234 (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
1235 ):
1236 error_result = True
1237
1238 info_dict = {
1239 "error_name": error_name,
1240 "error_result": error_result,
1241 "error_reason": error_reason,
1242 "param_reqs": param_reqs
1243 }
1244 return info_dict
1245
1246 @staticmethod
1247 def evWrongRank(check=False, **kwargs):
1248 all_ranks = (1, 2, 3, 4, 5)
1249
1250 # Make a list of incorrect ranks
1251 assert 'op' in kwargs
1252 op = kwargs['op']
1253 rmin, rmax = op['rank']
1254 rank_range = range(rmin, rmax + 1)
1255 incorrect_ranks = list(set(all_ranks) - set(rank_range))
1256 # Set minimum incorrect rank to 3 to avoid index error
1257 if op['op'] == Op.RESIZE:
1258 incorrect_ranks = [3, 5]
1259
1260 error_name = ErrorIf.WrongRank
1261 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1262 error_result = False
1263 error_reason = "Rank not supported for this operator"
1264
1265 if check:
1266 input_shape = kwargs['input_shape']
1267 if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
1268 error_result = True
1269
1270 info_dict = {
1271 "error_name": error_name,
1272 "error_result": error_result,
1273 "error_reason": error_reason,
1274 "param_reqs": param_reqs
1275 }
1276 return info_dict
1277
1278 @staticmethod
1279 def evWrongInputList(check=False, **kwargs):
1280 error_name = ErrorIf.WrongInputList
1281 param_reqs = {"rank": None, "dtype": None, "shape": None}
1282 error_result = False
1283 error_reason = "Op input list does not match expected input"
1284
1285 if check:
1286 op = kwargs['op']
1287 input_list = kwargs['input_list']
1288 num_operands = kwargs['num_operands']
1289 if len(input_list) != num_operands:
1290 error_result = True
1291
1292 info_dict = {
1293 "error_name": error_name,
1294 "error_result": error_result,
1295 "error_reason": error_reason,
1296 "param_reqs": param_reqs
1297 }
1298 return info_dict
1299
1300 @staticmethod
1301 def evWrongOutputList(check=False, **kwargs):
1302 error_name = ErrorIf.WrongOutputList
1303 param_reqs = {"rank": None, "dtype": None, "shape": None}
1304 error_result = False
1305 error_reason = "Op output list does not match expected output"
1306
1307 if check:
1308 output_list = kwargs['output_list']
1309 # Note this will be incorrect if an operator returns more than one output
1310 if len(output_list) != 1:
1311 error_result = True
1312
1313 info_dict = {
1314 "error_name": error_name,
1315 "error_result": error_result,
1316 "error_reason": error_reason,
1317 "param_reqs": param_reqs
1318 }
1319 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001320
1321 @staticmethod
1322 def evMaxDimExceeded(check=False, **kwargs):
1323 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001324 param_reqs = {
1325 "rank": [4,4],
1326 "dtype": [DType.INT8],
1327 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1328 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001329 error_result = False
1330 error_reason = "At least one maximum dimension is larger than 16384"
1331
1332 if check:
1333 input_shape = kwargs['input_shape'].shape
1334 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1335 if ((input_shape[1] > 16384) or
1336 (input_shape[2] > 16384) or
1337 (output_shape[0] > 16384) or
1338 (output_shape[1] > 16384)):
1339 error_result = True
1340
1341 info_dict = {
1342 "error_name": error_name,
1343 "error_result": error_result,
1344 "error_reason": error_reason,
1345 "param_reqs": param_reqs
1346 }
1347 return info_dict
1348
1349 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001350 def evBatchMismatch(check=False, **kwargs):
1351 error_name = ErrorIf.BatchMismatch
1352 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1353 error_result = False
1354 error_reason = "Input batch size not equal to output batch size"
1355
1356 assert 'op' in kwargs
1357 op = kwargs['op']
1358 rmin, rmax = op['rank']
1359 rank_range = range(rmin, rmax + 1)
1360
1361 if check:
1362 input_shape = kwargs['input_shape'].shape
1363 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1364
1365 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1366 error_result = True
1367
1368 info_dict = {
1369 "error_name": error_name,
1370 "error_result": error_result,
1371 "error_reason": error_reason,
1372 "param_reqs": param_reqs
1373 }
1374 return info_dict
1375
1376 @staticmethod
1377 def evChannelMismatch(check=False, **kwargs):
1378 error_name = ErrorIf.ChannelMismatch
1379 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1380 error_result = False
1381 error_reason = "Input channel size not equal to output channel size"
1382
1383 assert 'op' in kwargs
1384 op = kwargs['op']
1385 rmin, rmax = op['rank']
1386 rank_range = range(rmin, rmax + 1)
1387
1388 if check:
1389 input_shape = kwargs['input_shape'].shape
1390 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1391 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1392 error_result = True
1393
1394 info_dict = {
1395 "error_name": error_name,
1396 "error_result": error_result,
1397 "error_reason": error_reason,
1398 "param_reqs": param_reqs
1399 }
1400 return info_dict
1401
1402 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001403 def evStrideSmallerEqualZero(check=False, **kwargs):
1404 error_name = ErrorIf.StrideSmallerEqualZero
1405 param_reqs = {"rank": None, "dtype": None, "shape": None}
1406 error_result = False
1407 error_reason = "Stride value smaller than or equal to zero"
1408
1409 if check:
1410 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001411 output_dtype = kwargs['output_dtype']
1412 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1413 stride = kwargs['stride'] # Wrong input/output type test: integer strides were generated
1414 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001415 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001416 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1417 stride = kwargs['stride_fp'] # Wrong input/output type test: float strides were generated
Matthew Haddone86fd342021-09-07 16:12:21 +01001418 else:
1419 stride = kwargs['stride']
1420
1421 if min(stride) <= 0:
1422 error_result = True
1423
1424 info_dict = {
1425 "error_name": error_name,
1426 "error_result": error_result,
1427 "error_reason": error_reason,
1428 "param_reqs": param_reqs
1429 }
1430 return info_dict
1431
1432 @staticmethod
1433 def evStrideLargerEqualMax(check=False, **kwargs):
1434 error_name = ErrorIf.StrideLargerEqualMax
1435 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1436 error_result = False
1437 error_reason = "Stride value larger than or equal to maximum value"
1438
1439 if check:
1440 shift = kwargs['shift']
1441 input_dtype = kwargs['input_dtype']
1442 stride = kwargs['stride']
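# The integer stride is fixed-point with (1 << shift) as the unit, so a
# stride of (16 << shift) or more corresponds to a downscale factor of 16
# or more (assumed reading of the TOSA limit), e.g. shift = 10 caps the
# stride below 16 << 10 = 16384.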
1443 if input_dtype in [DType.INT8, DType.INT16]:
1444 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1445 error_result = True
1446 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1447 error_result = True
1448
1449 info_dict = {
1450 "error_name": error_name,
1451 "error_result": error_result,
1452 "error_reason": error_reason,
1453 "param_reqs": param_reqs
1454 }
1455 return info_dict
1456
1457
1458 @staticmethod
1459 def evStrideLargerDimension(check=False, **kwargs):
1460 error_name = ErrorIf.StrideLargerDimension
1461 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1462 error_result = False
1463 error_reason = "Stride value larger than the H/W dimension"
1464
1465 if check:
1466 shape = kwargs['input_shape'].shape
1467 input_dtype = kwargs['input_dtype']
1468 stride = kwargs['stride_fp']
1469
1470 if input_dtype == DType.FLOAT and ((stride[0] > shape[1]) or (stride[1] > shape[2])):
1471 error_result = True
1472
1473 info_dict = {
1474 "error_name": error_name,
1475 "error_result": error_result,
1476 "error_reason": error_reason,
1477 "param_reqs": param_reqs
1478 }
1479 return info_dict
1480
1481
1482 @staticmethod
1483 def evOffsetSmallerEqualMin(check=False, **kwargs):
1484 error_name = ErrorIf.OffsetSmallerEqualMin
1485 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1486 error_result = False
1487 error_reason = "Offset value smaller than or equal to minimum value"
1488
1489 if check:
1490 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001491 output_dtype = kwargs['output_dtype']
1492 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001493 offset = kwargs['offset_fp']
1494 else:
1495 offset = kwargs['offset']
1496
1497 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1498 error_result = True
1499 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1500 error_result = True
1501
1502 info_dict = {
1503 "error_name": error_name,
1504 "error_result": error_result,
1505 "error_reason": error_reason,
1506 "param_reqs": param_reqs
1507 }
1508 return info_dict
1509
1510 @staticmethod
1511 def evOffsetLargerEqualMax(check=False, **kwargs):
1512 error_name = ErrorIf.OffsetLargerEqualMax
1513 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1514 error_result = False
1515 error_reason = "Offset value larger than or equal to maximum value"
1516
1517 if check:
1518 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001519 output_dtype = kwargs['output_dtype']
1520 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001521 offset = kwargs['offset_fp']
1522 else:
1523 offset = kwargs['offset']
1524
1529 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1530 error_result = True
1531 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1532 error_result = True
1533
1534 info_dict = {
1535 "error_name": error_name,
1536 "error_result": error_result,
1537 "error_reason": error_reason,
1538 "param_reqs": param_reqs
1539 }
1540 return info_dict
1541
1542 @staticmethod
1543 def evShiftNotZero(check=False, **kwargs):
1544 error_name = ErrorIf.ShiftNotZero
1545 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1546 error_result = False
1547 error_reason = "Shift value must be zero for float input"
1548
1549 if check:
1550 shift = kwargs['shift']
1551 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001552 output_dtype = kwargs['output_dtype']
1553 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001554 error_result = True
1555
1556 info_dict = {
1557 "error_name": error_name,
1558 "error_result": error_result,
1559 "error_reason": error_reason,
1560 "param_reqs": param_reqs
1561 }
1562 return info_dict
1563
1564
1565 @staticmethod
1566 def evShiftSmallerOne(check=False, **kwargs):
1567 error_name = ErrorIf.ShiftSmallerOne
1568 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1569 error_result = False
1570 error_reason = "Shift value smaller than one"
1571
1572 if check:
1573 shift = kwargs['shift']
1574 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001575 output_dtype = kwargs['output_dtype']
1576 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001577 error_result = True
1578
1579 info_dict = {
1580 "error_name": error_name,
1581 "error_result": error_result,
1582 "error_reason": error_reason,
1583 "param_reqs": param_reqs
1584 }
1585 return info_dict
1586
1587 @staticmethod
1588 def evShiftLargerEleven(check=False, **kwargs):
1589 error_name = ErrorIf.ShiftLargerEleven
1590 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1591 error_result = False
1592 error_reason = "Shift value larger than eleven"
1593
1594 if check:
1595 shift = kwargs['shift']
1596 if shift > 11:
1597 error_result = True
1598
1599 info_dict = {
1600 "error_name": error_name,
1601 "error_result": error_result,
1602 "error_reason": error_reason,
1603 "param_reqs": param_reqs
1604 }
1605 return info_dict
1606
1607
Matthew Haddonb724efc2021-08-25 16:40:29 +01001608class TosaInvalidValidator:
1609
1610 @staticmethod
1611 def ivWrongDataTypeOrModeResize(**kwargs):
1612 input_dtype = kwargs["input_dtype"]
1613 args = kwargs["args"]
1614 mode = args[0]
1615 stride = args[1]
1616 stride_fp = args[4]
1617 output_dtype = args[8]
1618
1619 if mode == ResizeMode.BILINEAR:
1620 # Invalid output data type / Invalid input datatype
1621 return not (
1622 (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
1623 (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
1624 (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
1626 )
1627 elif mode == ResizeMode.NEAREST:
1628 # Invalid output data type / Invalid input datatype
1629 return (
1630 (input_dtype != output_dtype) or
1631 (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
1632 )
1633 else:
1634 # Invalid resize mode
1635 return True
1636
1637 @staticmethod
1638 def ivBadStride(**kwargs):
1639 input_dtype = kwargs["input_dtype"]
1640 args = kwargs["args"]
1641 stride_x = args[1][0]
1642 stride_y = args[1][1]
1643 stride_fp_x = args[4][0]
1644 stride_fp_y = args[4][1]
1645
1646 if input_dtype == DType.FLOAT:
1647 if stride_fp_x <= 0 or stride_fp_y <= 0:
1648 # Negative or zero stride
1649 return True
1650 else:
1651 if stride_x <= 0 or stride_y <= 0:
1652 # Negative or zero stride
1653 return True
1654 return False
1655
1656
Matthew Haddonb724efc2021-08-25 16:40:29 +01001657 @staticmethod
1658 def ivHeightWidthSmallerZero(**kwargs):
1659 opName = kwargs['opName']
1660
1661 inputShapes = kwargs['shapeList']
1662 input = inputShapes[0]
1663 if not opName.endswith("pool2d"):
1664 filter = inputShapes[1]
1665
1666 args = kwargs['args']
1667 strides = args[0]
1668 padding = args[1]
1669 dilations = args[2]
1670 if opName.endswith("pool2d"):
1671 kernel = args[2]
1672
1673 if opName.startswith('conv2d'):
1674 h = (
1675 input[1]
1676 - filter[1]
1677 - (filter[1] - 1) * (dilations[0] - 1)
1678 + padding[0]
1679 + padding[1]
1680 ) // strides[0] + 1
1681
1682 w = (
1683 input[2]
1684 - filter[2]
1685 - (filter[2] - 1) * (dilations[1] - 1)
1686 + padding[2]
1687 + padding[3]
1688 ) // strides[1] + 1
1689 elif opName.startswith("depthwise_conv2d"):
1690 h = (
1691 input[1]
1692 - filter[0]
1693 - (filter[0] - 1) * (dilations[0] - 1)
1694 + padding[0]
1695 + padding[1]
1696 ) // strides[0] + 1
1697
1698 w = (
1699 input[2]
1700 - filter[1]
1701 - (filter[1] - 1) * (dilations[1] - 1)
1702 + padding[2]
1703 + padding[3]
1704 ) // strides[1] + 1
1705 elif opName.endswith("pool2d"):
1706 h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
1707 w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
1708 else:
1709 assert False, "Unrecognized Op"
1710
1711 if h <= 0 or w <= 0:
1712 # Invalid parameter combination
1713 return True
1714 return False
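# Worked example for the conv2d branch above (illustrative numbers):
#   input H = 8, filter H = 3, dilation = 2, padding = (0, 0), stride = 2
#   h = (8 - 3 - (3 - 1) * (2 - 1) + 0 + 0) // 2 + 1 = 2
# whereas input H = 3, filter H = 3, dilation = 3 makes h <= 0, so that
# parameter combination is filtered out of the positive tests.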
1715
1716 @staticmethod
1717 def ivNonPositiveOutputShape(**kwargs):
1718 args = kwargs['args']
1719 output_shape = args[3]
1720 if output_shape[1] <= 0 or output_shape[2] <= 0:
1721 # Negative output shape
1722 return True
1723 return False
1724
1725
Kevin Cheng550ccc52021-03-03 11:21:43 -08001726
Eric Kunzee5e26762020-10-13 16:11:07 -07001727class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001728 # Maximum rank of tensor supported by test generator.
1729 TOSA_TENSOR_MAX_RANK = 6
1730
Eric Kunzee5e26762020-10-13 16:11:07 -07001731 def __init__(self, args):
1732 self.args = args
1733 self.basePath = args.output_dir
1734 self.random_seed = args.random_seed
1735 self.ser = None
1736 self.rng = np.random.default_rng(self.random_seed)
1737 self.createDynamicOpLists()
1738 self.initOpListDefaults()
1739 self.quantGen = TosaQuantGen()
1740 # Force makeShape to do a specific starting shape
1741 self.targetted_shape = None
1742
1743 def createSerializer(self, opName, testPath):
1744 self.testPath = os.path.join(opName, testPath)
1745
1746 fullPath = os.path.join(self.basePath, self.testPath)
1747 os.makedirs(fullPath, exist_ok=True)
1748 self.ser = ts.TosaSerializer(fullPath)
1749
1750 def getSerializer(self):
1751 return self.ser
1752
1753 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001754 with open(
1755 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
1756 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07001757 fd.write(self.ser.serialize())
1758
Kevin Cheng550ccc52021-03-03 11:21:43 -08001759 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
1760 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07001761
Matthew Haddon74567092021-07-16 15:38:20 +01001762 def resetRNG(self, seed=None):
1763 if seed is None:
1764 seed = self.random_seed + 1
1765 self.rng = np.random.default_rng(seed)
1766
Eric Kunzee5e26762020-10-13 16:11:07 -07001767 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07001768 if dtype == DType.BOOL:
1770 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07001771 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001772 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001773 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001774 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001775 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
1776 elif dtype == DType.UINT8:
1777 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001778 elif dtype == DType.INT16:
1779 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
1780 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001781 return np.int32(
1782 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
1783 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001784 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001785 return np.int64(
1786 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
1787 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001788 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001789 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001790 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001791 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001792
Kevin Cheng989cb052021-04-28 16:29:44 -07001793 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001794 placeholders = []
1795
Kevin Cheng989cb052021-04-28 16:29:44 -07001796 assert len(shape_list) == len(dtype_list)
1797
1798 for idx, shape in enumerate(shape_list):
1799 arr = self.getRandTensor(shape, dtype_list[idx])
1800 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001801
1802 return placeholders
1803
Kevin Cheng989cb052021-04-28 16:29:44 -07001804 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001805 consts = []
1806
Kevin Cheng989cb052021-04-28 16:29:44 -07001807 assert len(shape_list) == len(dtype_list)
1808
1809 for idx, shape in enumerate(shape_list):
1810 arr = self.getRandTensor(shape, dtype_list[idx])
1811 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001812
1813 return consts
1814
1815 def makeShape(self, rank):
1816 if self.targetted_shape:
1817 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001818 return np.int32(
1819 self.rng.integers(
1820 low=self.args.tensor_shape_range[0],
1821 high=self.args.tensor_shape_range[1],
1822 size=rank,
1823 )
1824 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001825
1826 def setTargetShape(self, shape):
1827 self.targetted_shape = shape
1828
1829 def randInt(self, low=0, high=256):
1830 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
1831
1832 def getRandNumberDType(self, dtype):
1833 if dtype == DType.FLOAT:
1834 return self.rng.random()
1835 elif dtype == DType.BOOL:
1836 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07001837 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001838 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001839 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07001840 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001841 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07001842 elif dtype == DType.INT16:
1843 low, high = (-32768, 32768)
1844 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001845 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07001846 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001847 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07001848 # Special size
1849 return np.int64(self.rng.integers(low, high, size=1))[0]
1850 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001851 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001852
1853 return np.int32(self.rng.integers(low, high, size=1))[0]
1854
1855 def shapeStr(self, shape):
1856
1857 sStr = []
1858 # Convert to strings
1859 for i in shape:
1860 sStr.append(str(i))
1861
Kevin Cheng550ccc52021-03-03 11:21:43 -08001862 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001863
1864 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001865 if isinstance(t, list):
1866 assert len(t) >= 2
1867 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001868 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001869 if t == DType.BOOL:
1870 return "b"
1871 elif t == DType.INT4:
1872 return "i4"
1873 elif t == DType.INT8:
1874 return "i8"
1875 elif t == DType.UINT8:
1876 return "u8"
1877 elif t == DType.INT16:
1878 return "i16"
1879 elif t == DType.INT32:
1880 return "i32"
1881 elif t == DType.INT48:
1882 return "i48"
1883 elif t == DType.FLOAT:
1884 return "float"
1885 else:
1886 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001887
1888 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001889 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001890 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001891 return 4
1892 elif t == DType.INT8:
1893 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001894 elif t == DType.UINT8:
1895 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001896 elif t == DType.INT16:
1897 return 16
1898 elif t == DType.INT32:
1899 return 32
1900 elif t == DType.INT48:
1901 return 48
1902 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001903 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001904
1905 # Argument generators
1906 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1907 # where the string descriptor is used to generate the test name and
1908 # the build_fcn_arg_list is expanded and passed to the operator test
1909 # build function
1910
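# Illustrative sketch of one returned entry (descriptor string and values
# hypothetical):
#   ("st11_pad0000_kern22", [[1, 1], [0, 0, 0, 0], [2, 2]])
# The string becomes part of the generated test name and the list is
# splatted into the build function, e.g.
#   build_pool2d(self, op, input, *args)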
Kevin Cheng550ccc52021-03-03 11:21:43 -08001911 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001912 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01001913 # build_placeholder returns an int, ABS/other ops does not
1914 if isinstance(op, int):
1915 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1916 else:
1917 self.ser.addOperator(op['op'], [a.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001918 return result_tens
1919
1920 def build_binary_broadcast(self, op, a, b):
1921 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001922 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001923 return result_tens
1924
1925 def build_binary_nonbroadcast(self, op, a, b):
1926 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001927 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001928 return result_tens
1929
Kevin Chengaee1fac2020-11-11 13:54:06 -08001930 def build_arithmetic_right_shift(self, op, a, b, round):
1931 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1932
1933 attr = ts.TosaSerializerAttribute()
1934 attr.ArithmeticRightShiftAttribute(round)
1935
Matthew Haddon848efb42021-09-09 12:30:53 +01001936 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001937 return result_tens
1938
1939 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001940 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1941
1942 # Special for multiply:
1943 # Force the result to INT32 for INT types
1944 if a.dtype != DType.FLOAT:
1945 result_tens.setDtype(DType.INT32)
1946
Kevin Chengaee1fac2020-11-11 13:54:06 -08001947 attr = ts.TosaSerializerAttribute()
1948 attr.MulAttribute(shift)
1949
Matthew Haddon848efb42021-09-09 12:30:53 +01001950 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001951 return result_tens
1952
1953 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001954 # Constant size depending on type, random values
1955 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07001956 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001957 table_arr = self.getRandTensor([513], table_dtype)
1958 else:
1959 assert a.dtype == DType.INT8
1960 table_dtype = DType.INT8
1961 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001962
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001963 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1964 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01001965 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001966
1967 return result_tens
1968
1969 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001970 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001971 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001972 return result_tens
1973
1974 def build_comparison(self, op, a, b):
1975 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001976 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001977 return result_tens
1978
1979 def build_argmax(self, op, a, axis):
1980 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1981
1982 attr = ts.TosaSerializerAttribute()
1983 attr.AxisAttribute(axis)
1984
Matthew Haddon848efb42021-09-09 12:30:53 +01001985 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001986 return result_tens
1987
Matthew Haddonb724efc2021-08-25 16:40:29 +01001988 def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001989 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1990
1991 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001992 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001993
Matthew Haddon848efb42021-09-09 12:30:53 +01001994 self.ser.addOperator(op['op'], [input.name], [result_tens.name], attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001995 return result_tens
1996
1997 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001998 assert len(padding) == 4
1999 result_tens = OutputShaper.conv2dOp(
2000 self.ser, ifm, filter, strides, padding, dilations
2001 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002002
2003 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002004 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002005
Kevin Cheng550ccc52021-03-03 11:21:43 -08002006 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002007 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002008 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002009 return result_tens
2010
Kevin Cheng1533b852021-09-01 12:51:58 -07002011 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
2012 assert len(padding) == 6
2013 result_tens = OutputShaper.conv3dOp(
2014 self.ser, ifm, filter, strides, padding, dilations
2015 )
2016
2017 attr = ts.TosaSerializerAttribute()
2018 attr.ConvAttribute(padding, strides, dilations)
2019
2020 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002021 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07002022 )
2023 return result_tens
2024
Kevin Cheng550ccc52021-03-03 11:21:43 -08002025 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07002026 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002027 ):
2028 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07002029 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
2030
2031 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002032 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002033
Kevin Cheng550ccc52021-03-03 11:21:43 -08002034 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002035 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002036 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002037 return result_tens
2038
Kevin Cheng550ccc52021-03-03 11:21:43 -08002039 def build_depthwise_conv2d(
2040 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
2041 ):
2042 result_tens = OutputShaper.depthwiseConv2dOp(
2043 self.ser, ifm, filter, strides, padding, dilations
2044 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002045
2046 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002047 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002048
Kevin Cheng550ccc52021-03-03 11:21:43 -08002049 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002050 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002051 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002052 return result_tens
2053
2054 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
2055 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
2056
Kevin Cheng550ccc52021-03-03 11:21:43 -08002057 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002058 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002059 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002060 return result_tens
2061
2062 def build_matmul(self, op, a, b, qinfo):
2063 result_tens = OutputShaper.matmulOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002064 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002065 return result_tens
2066
2067 def build_reduce(self, op, a, axis):
2068 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
2069
2070 attr = ts.TosaSerializerAttribute()
2071 attr.AxisAttribute(axis)
2072
Matthew Haddon848efb42021-09-09 12:30:53 +01002073 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002074 return result_tens
2075
2076 def build_clamp(self, op, a):
2077 result_tens = OutputShaper.unaryOp(self.ser, a)
2078
2079 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01002080 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07002081
2082 if a.dtype == DType.FLOAT:
2083 attr.ClampAttribute(0, 0, min(v), max(v))
2084 else:
2085 attr.ClampAttribute(min(v), max(v), 0, 0)
2086
Matthew Haddon848efb42021-09-09 12:30:53 +01002087 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002088 return result_tens
2089
2090 def build_leaky_relu(self, op, a):
2091 result_tens = OutputShaper.unaryOp(self.ser, a)
2092 attr = ts.TosaSerializerAttribute()
2093
2094 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
2095
Matthew Haddon848efb42021-09-09 12:30:53 +01002096 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002097 return result_tens
2098
2099 # Needs an additional type/input
2100 def build_prelu(self, op, a):
2101 result_tens = OutputShaper.unaryOp(self.ser, a)
2102
Matthew Haddon848efb42021-09-09 12:30:53 +01002103 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002104 return result_tens
2105
Eric Kunzee5e26762020-10-13 16:11:07 -07002106 def build_sigmoid(self, op, a):
2107 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01002108 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002109 return result_tens
2110
2111 def build_tanh(self, op, a):
2112 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01002113 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002114 return result_tens
2115
Matthew Haddon818ab902021-07-27 09:12:49 +01002116 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07002117 assert isinstance(a[-1], int)
Matthew Haddon818ab902021-07-27 09:12:49 +01002118
2119 # To store variable length list of input tensors we need to store axis along with it
2120 axis = a[-1]
2121 a = a[:-1]
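# e.g. build_concat(op, t0, t1, t2, 2) unpacks to axis = 2, a = (t0, t1, t2)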
2122
2123 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002124
2125 attr = ts.TosaSerializerAttribute()
2126 attr.AxisAttribute(axis)
2127
Matthew Haddon818ab902021-07-27 09:12:49 +01002128 input_tensor_names = []
2129 for tensor in a:
2130 input_tensor_names.append(tensor.name)
2131
Matthew Haddon848efb42021-09-09 12:30:53 +01002132 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
2133 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002134
2135 def build_pad(self, op, a, padding, qinfo):
2136 result_tens = OutputShaper.padOp(self.ser, a, padding)
2137
2138 # Need to turn the padding array into a TOSA tensor here.
2139 # This is one of the few tensor operands that does not get
2140 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08002141 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07002142
Kevin Cheng550ccc52021-03-03 11:21:43 -08002143 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002144 op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002145 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002146 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002147
2148 def build_reshape(self, op, a, newShape):
2149 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
2150
2151 attr = ts.TosaSerializerAttribute()
2152 attr.ReshapeAttribute(newShape)
2153
Matthew Haddon848efb42021-09-09 12:30:53 +01002154 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002155 return result_tens
2156
2157 def build_reverse(self, op, a, axis):
2158 result_tens = OutputShaper.unaryOp(self.ser, a)
2159
2160 attr = ts.TosaSerializerAttribute()
2161 attr.AxisAttribute(axis)
2162
Matthew Haddon848efb42021-09-09 12:30:53 +01002163 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002164 return result_tens
2165
2166 def build_transpose(self, op, a, perms):
2167 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
2168
Kevin Cheng550ccc52021-03-03 11:21:43 -08002169 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07002170
Matthew Haddon848efb42021-09-09 12:30:53 +01002171 self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002172 return result_tens
2173
2174 def build_slice(self, op, a, begin, size):
2175 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
2176
2177 attr = ts.TosaSerializerAttribute()
2178 attr.SliceAttribute(begin, size)
2179
Matthew Haddon848efb42021-09-09 12:30:53 +01002180 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002181 return result_tens
2182
2183 def build_tile(self, op, a, multiples):
2184 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
2185
2186 attr = ts.TosaSerializerAttribute()
2187 attr.TileAttribute(multiples)
2188
Matthew Haddon848efb42021-09-09 12:30:53 +01002189 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002190 return result_tens
2191
Kevin Cheng77d0f762020-11-24 10:26:32 -08002192 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07002193
2194 # Create a new indices tensor here with data that doesn't
2195 # exceed the dimensions of the values tensor
2196
Kevin Cheng550ccc52021-03-03 11:21:43 -08002197 K = values.shape[1] # K
2198 W = self.randInt(
2199 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
2200 ) # W
2201 indicies_arr = np.int32(
2202 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
2203 ) # (N, W)
2204 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
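# e.g. values of shape (N, K, C) = (2, 5, 3) with W = 4 give an indices
# tensor of shape (2, 4) whose entries lie in [0, 5)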
Eric Kunzee5e26762020-10-13 16:11:07 -07002205
Kevin Cheng77d0f762020-11-24 10:26:32 -08002206 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07002207
Matthew Haddon848efb42021-09-09 12:30:53 +01002208 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002209
2210 return result_tens
2211
Kevin Cheng77d0f762020-11-24 10:26:32 -08002212 def build_scatter(self, op, values_in, input):
2213
2214 # Create a new indices tensor here with data that doesn't
2215 # exceed the dimensions of the values_in tensor
2216
Kevin Cheng550ccc52021-03-03 11:21:43 -08002217 K = values_in.shape[1] # K
2218 W = input.shape[1] # W
2219 indicies_arr = np.int32(
2220 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
2221 ) # (N, W)
2222 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002223
2224 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
2225
Kevin Cheng550ccc52021-03-03 11:21:43 -08002226 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002227 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002228 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002229
2230 return result_tens
2231
Matthew Haddon848efb42021-09-09 12:30:53 +01002232
Kevin Cheng550ccc52021-03-03 11:21:43 -08002233 def build_resize(
2234 self,
2235 op,
2236 input,
2237 mode,
2238 stride,
2239 offset,
2240 shift,
2241 stride_fp,
2242 offset_fp,
2243 output_dims,
2244 input_dtype,
2245 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002246 validator_fcns,
2247 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002248 ):
2249 result_tens = OutputShaper.resizeOp(
2250 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002251 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002252 input,
2253 mode,
2254 stride,
2255 offset,
2256 shift,
2257 stride_fp,
2258 offset_fp,
2259 output_dims,
2260 input_dtype,
2261 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002262 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08002263 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002264
Matthew Haddon848efb42021-09-09 12:30:53 +01002265 # Invalidate Input/Output list for error if checks.
2266 input_list = [input.name]
2267 output_list = [result_tens.name]
2268 pCount, cCount = op["operands"]
2269 num_operands = pCount + cCount
2270 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01002271
Matthew Haddon848efb42021-09-09 12:30:53 +01002272 TosaErrorValidator.evValidateErrorIfs(
2273 self.ser,
2274 validator_fcns,
2275 error_name,
2276 op=op,
2277 mode=mode,
2278 shift=shift,
2279 input_dtype=input_dtype,
2280 output_dtype=output_dtype,
2281 input_shape=input,
2282 output_shape=output_dims,
2283 offset=offset,
2284 offset_fp=offset_fp,
2285 stride=stride,
2286 stride_fp=stride_fp,
2287 input_list=input_list,
2288 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002289 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01002290 num_operands=num_operands,
2291 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002292
Eric Kunzee5e26762020-10-13 16:11:07 -07002293 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08002294
Kevin Cheng550ccc52021-03-03 11:21:43 -08002295 attr.ResizeAttribute(
2296 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
2297 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002298
Matthew Haddon848efb42021-09-09 12:30:53 +01002299 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002300 return result_tens
2301
2302 def build_identityn(self, op, val, val2):
2303
Kevin Cheng550ccc52021-03-03 11:21:43 -08002304 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002305 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002306 self.ser.addOperator(
2307 op['op'], [val.name, val2.name], [result_tens.name, result_tens2.name]
2308 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002309 return result_tens
2310
Kevin Cheng17e92022021-10-01 14:33:33 -07002311 def build_const(self, op, val):
2312 self.ser.addOutputTensor(val)
2313 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002314
2315 # Type Conversion
2316 def build_cast(self, op, val, out_dtype):
2317 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002318 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002319 return result_tens
2320
2321 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
2322 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2323
2324 if per_channel:
2325 nc = val.shape[-1]
2326 else:
2327 nc = 1
2328
2329 in_type_width = self.typeWidth(val.dtype)
2330 out_type_width = self.typeWidth(out_dtype)
2331
Kevin Cheng3a478572021-01-22 17:21:02 -08002332 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002333 input_zp = self.randInt(-128, 128)
2334 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002335 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002336 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002337 in_type_width = in_type_width + 1
2338 else:
2339 input_zp = 0
2340
Kevin Cheng3a478572021-01-22 17:21:02 -08002341 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002342 output_zp = self.randInt(-128, 128)
2343 out_type_width = out_type_width + 1
2344 elif out_dtype == DType.UINT8:
2345 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002346 out_type_width = out_type_width + 1
2347 else:
2348 output_zp = 0
2349
2350 # Calculate scale based on:
2351 # scale = a * (2^output_width) / (2^input_width)
2352
2353 a = np.float32(self.rng.random(size=[nc]))
2354 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2355
2356 if scale32:
Matthew Haddonb724efc2021-08-25 16:40:29 +01002358 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002359 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2360 else:
2361 # Cap the scaling at 2^15 - 1 for scale16
2362 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2363
Kevin Cheng550ccc52021-03-03 11:21:43 -08002364 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002365
2366 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2367 shift_arr = np.int32(np.zeros(shape=[nc]))
2368
2369 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002370 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2371 scale_arr[i], scale32
2372 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002373
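# computeMultiplierAndShift decomposes each real scale into an integer
# multiplier and a right shift, scale ~= multiplier * 2^-shift. Sketch for
# scale32, assuming the multiplier is normalised into [1 << 30, 1 << 31):
#   scale = 1.5 -> multiplier = 1610612736 (= 1.5 * 2^30), shift = 30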
Kevin Cheng550ccc52021-03-03 11:21:43 -08002374 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002375
2376 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002377 attr.RescaleAttribute(
2378 input_zp,
2379 output_zp,
2380 multiplier_arr,
2381 shift_arr,
2382 scale32,
2383 double_round,
2384 per_channel,
2385 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002386
Matthew Haddon848efb42021-09-09 12:30:53 +01002387 self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002388 return result_tens
2389
2390 def build_cond_if_const(self, op, then_tens, else_tens, cond):
2391 # For cond_if with constants, we're supplied with then/else tensors that we ignore
2392 # (except for the generated shape) and the condition. Build Then/Else blocks
2393 # and fill them with const nodes for the body.
2394
2395 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002396 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002397
2398 # Make then/else tensors
2399 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01002400 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
2401 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002402
2403 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08002404 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002405
2406 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002407 then_block = "THEN_BLOCK"
2408 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002409 attr = ts.TosaSerializerAttribute()
2410 attr.CondIfAttribute(then_block, else_block)
2411
2412 # Finally, build the op and the two blocks
Matthew Haddon848efb42021-09-09 12:30:53 +01002413 self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002414
2415 self.ser.startBasicBlock(then_block)
2416 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002417 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002418 self.ser.addOutputTensor(then_tens)
2419
2420 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002421 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002422 self.ser.addOutputTensor(else_tens)
2423
2424 return result_tens
2425
2426 def build_cond_if_binary(self, op, a, b, cond):
2427 # For cond_if with a binary op in the then/else blocks, take a and b and
2428 # alternately add or subtract them based on the condition
2429
2430 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002431 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002432
Kevin Cheng550ccc52021-03-03 11:21:43 -08002433 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002434
2435 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002436 then_block = "THEN_BLOCK"
2437 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002438 attr = ts.TosaSerializerAttribute()
2439 attr.CondIfAttribute(then_block, else_block)
2440
2441 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002442 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002443 op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08002444 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002445
2446 self.ser.startBasicBlock(then_block)
2447 self.ser.addInputTensor(a)
2448 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002449 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002450 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
2451
2452 self.ser.startBasicBlock(else_block)
2453 self.ser.addInputTensor(a)
2454 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002455 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002456 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
2457
2458 return result_tens
2459
2460 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002461 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002462
Kevin Cheng550ccc52021-03-03 11:21:43 -08002463 cond_block = "COND_BLOCK"
2464 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002465
2466 attr = ts.TosaSerializerAttribute()
2467 attr.WhileLoopAttribute(cond_block, body_block)
2468
2469 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002470 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002471 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002472 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002473
2474 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002475 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2476 a_out = self.ser.addIntermediate(a.shape, a.dtype)
2477 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002478
2479 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002480 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002481 op['op'],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002482 [iter.name, a.name, acc.name],
2483 [iter_out.name, a_out.name, acc_out.name],
2484 attr,
2485 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002486 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002487
2488 # COND block (input: iter, output: cond_tens )
2489 self.ser.startBasicBlock(cond_block)
2490 self.ser.addInputTensor(iter)
2491 self.ser.addInputTensor(a)
2492 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002493 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
2494 cond_tens = self.ser.addOutput([], DType.BOOL)
2495 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002496
2497 # BODY block (input: a, acc, iter, output: a, acc, iter)
2498 # Note that local intermediate tensors need to be declared here for the outputs
2499 self.ser.startBasicBlock(body_block)
2500 self.ser.addInputTensor(iter)
2501 self.ser.addInputTensor(a)
2502 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002503 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
2504 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2505 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002506 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2507 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2508 self.ser.addOutputTensor(iter_body_out)
2509 self.ser.addOutputTensor(a)
2510 self.ser.addOutputTensor(acc_body_out)
2511
2512 return acc_out
2513
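# The WHILE_LOOP built above computes acc = a * iter_val by repeated
# addition. Illustrative trace for iter_val = 3 (cond is iter > 0):
#   iter: 3 -> 2 -> 1 -> 0
#   acc:  0 -> a -> 2a -> 3a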
Matthew Haddon1c00b712021-10-01 15:51:03 +01002514 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
2515 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2516 default_test_rank_range = range(1, 5)
2517 if not shapeFilter:
2518 shapeFilter = [None]
2519
2520 # Calculate the filters based on what is requested and what the operator allows
2521 rmin, rmax = op["rank"]
2522 if rankFilter is not None:
2523 cleanRankFilter = []
2524 # Ensure rankFilter values are allowed by operator
2525 for rank in rankFilter:
2526 if rank >= rmin and rank <= rmax:
2527 cleanRankFilter.append(rank)
2528 elif rankFilter is None and shapeFilter[0] is None:
2529 cleanRankFilter = []
2530 # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
2531 rankRange = range(rmin, rmax + 1)
2532 for rank in rankRange:
2533 if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
2534 cleanRankFilter.append(rank)
2535 else:
2536 cleanRankFilter = range(rmin, rmax + 1)
2537
2538 dtypes = op["types"]
2539 if dtypeFilter is not None:
2540 cleanDtypeFilter = []
2541 # Ensure filtered dtypes are allowed by operator
2542 for dtype in dtypeFilter:
2543 if dtype in dtypes:
2544 cleanDtypeFilter.append(dtype)
2545 else:
2546 cleanDtypeFilter = dtypes
2547
2548 if testType == 'positive':
2549 filterDict = {
2550 'shapeFilter': shapeFilter,
2551 'rankFilter': cleanRankFilter,
2552 'dtypeFilter': cleanDtypeFilter
2553 }
2554 return filterDict
2555 elif testType == 'negative':
2556 validator_info = validator(check=False, op=op)
2557 error_arguments = validator_info['param_reqs']
2558
2559 # Set parameters as required
2560 if error_arguments['rank'] is not None:
2561 rankFilter = error_arguments['rank']
2562 else:
2563 rankFilter = cleanRankFilter
2564
2565 if error_arguments['dtype'] is not None:
2566 dtypeFilter = error_arguments['dtype']
2567 else:
2568 dtypeFilter = cleanDtypeFilter
2569
2570 if error_arguments['shape'] is not None:
2571 shapeFilter = error_arguments['shape']
2572 else:
2573 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
2574
2575 filterDict = {
2576 'shapeFilter': shapeFilter,
2577 'rankFilter': rankFilter,
2578 'dtypeFilter': dtypeFilter
2579 }
2580 return filterDict
2581
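# Illustrative sketch: for an op with rank bounds (1, 6), no user filters
# and testType == 'positive', create_filter_lists returns
#   {'shapeFilter': [None], 'rankFilter': [1, 2, 3, 4],
#    'dtypeFilter': op["types"]}
# i.e. the default 1-4 test rank range intersected with what the op allows.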
2582
Kevin Cheng550ccc52021-03-03 11:21:43 -08002583 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01002584 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08002585 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002586
2587 try:
2588 op = self.TOSA_OP_LIST[opName]
2589 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002590 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002591
2592 # Initialize a new random number generator
2593 self.rng = np.random.default_rng(self.random_seed)
2594
Kevin Cheng550ccc52021-03-03 11:21:43 -08002595 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002596
Eric Kunzee5e26762020-10-13 16:11:07 -07002597 # Test list consists of tuples of:
2598 # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
2599 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01002600 if testType == 'negative' and "error_if_validators" in op:
2601 error_if_validators = op["error_if_validators"]
2602 else:
2603 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002604
Matthew Haddon1c00b712021-10-01 15:51:03 +01002605 for validator in error_if_validators:
2606 if validator is not None:
2607 error_name = validator(check=False, op=op)['error_name']
2608 #print("error_name: ", error_name)
2609 else:
2610 error_name = None
2611
2612 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
2613 cleanRankFilter = filterDict['rankFilter']
2614 cleanDtypeFilter = filterDict['dtypeFilter']
2615 cleanShapeFilter = filterDict['shapeFilter']
2616 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
2617
2618 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07002619 if opName.startswith("conv3d"):
2620 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01002621 for t in cleanDtypeFilter:
2622 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002623 # Filter out by rank
2624 if shape is not None and len(shape) != r:
2625 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002626 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002627 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002628
Matthew Haddon74567092021-07-16 15:38:20 +01002629 shapeStr = self.shapeStr(shapeList[0])
2630 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002631
Matthew Haddon74567092021-07-16 15:38:20 +01002632 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2633 argList = []
2634 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002635 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002636 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002637 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002638
Matthew Haddon74567092021-07-16 15:38:20 +01002639 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002640 if testType == 'positive':
2641 if argStr:
2642 testStr = "{}_{}_{}_{}".format(
2643 opName, shapeStr, typeStr, argStr
2644 )
2645 else:
2646 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
2647 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01002648 if argStr:
2649 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2650 opName, error_name, shapeStr, typeStr, argStr
2651 )
2652 else:
2653 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002654
2655 testList.append((opName, testStr, t, error_name, shapeList, args))
2656
2657 if testType == 'positive':
2658 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2659 if "invalid_test_validators" in op:
2660 invalid_test_validators = op["invalid_test_validators"]
2661 clean_testList = []
                for test in testList:
                    remove_test = False
                    for validator_fcn in invalid_test_validators:
                        if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
                            remove_test = True
                    if not remove_test:
                        clean_testList.append(test)
2669 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002670
2671 return testList
2672
Matthew Haddone86fd342021-09-07 16:12:21 +01002673
2674 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07002675 try:
2676 op = self.TOSA_OP_LIST[opName]
        except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002678 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002679
2680 # Create a serializer
2681 self.createSerializer(opName, testStr)
2682
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002684 if "error_if_validators" in op:
2685 error_if_validators = op["error_if_validators"]
2686 else:
2687 error_if_validators = None
2688
Kevin Cheng550ccc52021-03-03 11:21:43 -08002689 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002690 num_operands = pCount + cCount
2691
2692 if isinstance(dtype_or_dtypeList, list):
2693 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002694 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002695 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002696 else:
2697 dtypeList = [dtype_or_dtypeList] * (num_operands)
2698
Kevin Cheng93a16282021-08-31 16:14:03 -07002699 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002700 assert (
2701 len(shapeList) == num_operands
2702 ), "shapeList length {} must match number of operands {}".format(
2703 len(shapeList), num_operands
2704 )
2705 assert (
2706 len(dtypeList) == num_operands
2707 ), "dtypeList length {} must match number of operands {}".format(
2708 len(dtypeList), num_operands
2709 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002710
2711 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002712 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002713 except KeyError:
2714 qgen = None
2715
2716 # Build the random tensor operands and the test
        tens = self.generate_tensors(op, dtypeList, shapeList, testArgs)
2720
2721 if qgen is not None:
2722 qinfo = qgen(self, op, dtype_or_dtypeList)
2723 else:
2724 qinfo = None
2725
2726 try:
2727 if error_if_validators is None:
2728 if qinfo is not None:
2729 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
2730 else:
2731 resultName = build_fcn(self, op, *tens, *testArgs)
2732 else:
2733 if qinfo is not None:
2734 resultName = build_fcn(self, op, *tens, *testArgs, qinfo, error_if_validators, error_name)
2735 else:
2736 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
2737 except TypeError as e:
2738 print(
2739 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
2740 build_fcn, tens, testArgs
2741 )
2742 )
2743 raise e
2744
2745 if resultName is None:
            print("Invalid ERROR_IF test created: {}".format(testStr))
2747
2748 # Save the serialized test
2749 self.serialize("test")
2750
2751
2752 def generate_tensors(self, op, dtypeList, shapeList, testArgs):
2753 pCount, cCount = op["operands"]
2754
2755 tens = []
Jeremy Johnsonef509a42021-09-07 13:59:47 +01002756 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32:
            # Make sure the operation does not cause value saturation - where
            # the result wraps because there are too few bits to store the answer
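            # Illustrative example (assumed values): a = 2**31 - 10, b = 50
            # res = 2**31 + 40 exceeds max_i32 by sat_max = 41, so b is
            # clipped to 50 - 41 = 9 and a + 9 = max_i32 no longer saturates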
2759 assert (
2760 pCount == 2 and cCount == 0
2761 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
2762
2763 placeholders = []
2764 add = (op["op"] == Op.ADD)
2765 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2766 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2767 if add:
2768 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
2769 else:
2770 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
2771
2772 # Work out the saturation limits
            max_i32 = (1 << 31) - 1
2774 min_i32 = -(1 << 31)
2775 max_arr = np.full(shapeList[1], max_i32)
2776 min_arr = np.full(shapeList[1], min_i32)
2777
            # Find how much the values exceed the maximum/minimum
2779 sat_max_arr = np.maximum(res_arr - max_arr, 0)
2780 sat_min_arr = np.minimum(res_arr - min_arr, 0)
2781
2782 if not add:
                # Swap and negate the saturation values, as we need to perform the opposite operations
2784 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
2785
2786 # Create new array of unsaturated values by clipping values as needed
2787 b_unsat_arr = b_arr
2788 if (sat_max_arr != 0).any():
2789 # Clip values that cause saturation
2790 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
2791 # Reduce axes in unsaturated tensor to match original tensor
2792 for axis, dim in enumerate(b_arr.shape):
2793 if dim != b_unsat_arr.shape[axis]:
                        assert dim == 1, "Op.ADD / Op.SUB dimension must be 1 or match to be broadcastable"
2795 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
2796
2797 if (sat_min_arr != 0).any():
2798 # Clip values that cause saturation
2799 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
2800 # Reduce axes in unsaturated tensor to match original tensor
2801 for axis, dim in enumerate(b_arr.shape):
2802 if dim != b_unsat_arr.shape[axis]:
                        assert dim == 1, "Op.ADD / Op.SUB dimension must be 1 or match to be broadcastable"
2804 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
2805
2806 placeholders.append(
2807 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2808 )
2809 placeholders.append(
2810 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
2811 )
2812
2813 tens.extend(placeholders)
2814 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
2815 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002816 assert (
2817 pCount == 2 and cCount == 0
2818 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08002819
2820 placeholders = []
2821 for idx, shape in enumerate(shapeList[:]):
2822 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07002823 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002824 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002825 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002826 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002827 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002828 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
2829 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002830 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002831 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002832 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07002833 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002834
2835 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01002836 elif op["op"] == Op.SELECT:
2837 # Set datatype of condition tensor to boolean
2838 dtypeList[0] = DType.BOOL
2839 tens.extend(
2840 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2841 )
2842 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddon459443c2021-08-23 16:43:13 +01002843 elif op["op"] == Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002844 assert (
2845 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01002846 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002847
2848 placeholders = []
2849
Matthew Haddon459443c2021-08-23 16:43:13 +01002850 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002851 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07002852 # 2. dividend == -(1<<31) and divisor == -1
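            # (case 2 would overflow: -(2**31) / -1 = 2**31, one more than
            # the int32 maximum of 2**31 - 1)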
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002853 while True:
2854 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2855 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2856
2857 if (divisor_arr == 0).any():
2858 continue
2859
Kevin Cheng47315e12021-05-13 17:41:28 -07002860 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002861 continue
2862
2863 break
2864
2865 placeholders.append(
2866 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
2867 )
2868 placeholders.append(
2869 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
2870 )
2871
2872 tens.extend(placeholders)
2873 elif op["op"] == Op.MUL:
2874 assert (
2875 pCount == 2 and cCount == 0
2876 ), "Op.MUL must have 2 placeholders, 0 consts"
2877
2878 if dtypeList[0] == DType.FLOAT:
2879 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
2880 else:
2881 placeholders = []
2882
                # Make sure the multiply result is within int32 range
2884 shift = testArgs[0]
2885 if dtypeList[0] == DType.INT8:
2886 num_bits = 8
2887 elif dtypeList[0] == DType.INT16:
2888 num_bits = 16
2889 elif dtypeList[0] == DType.INT32:
2890 num_bits = 32
2891 else:
2892 raise Exception("OpMul: invalid input dtype")
2893
                low = -(2 ** (num_bits - 1))
                high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[1])
                )
2904
2905 i = 0
2906 while True:
2907
2908 a_arr_64 = a_arr.astype(np.int64)
2909 b_arr_64 = b_arr.astype(np.int64)
2910
2911 if shift > 0:
2912 rounding = 1 << (shift - 1)
2913 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
2914 else:
2915 result_arr = a_arr_64 * b_arr_64
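                    # Rounded-shift sketch (assumed values): a = 100, b = 7,
                    # shift = 4 -> rounding = 1 << 3 = 8 and
                    # result = (700 + 8) >> 4 = 44, i.e. round(700 / 16)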
2916
2917 if (result_arr > -(2 ** 31)).all() and (
2918 result_arr <= ((2 ** 31) - 1)
2919 ).all():
2920 break
2921
2922 i = i + 1
2923 a_arr = a_arr // 2
2924 b_arr = b_arr // 2
2925
2926 placeholders.append(
2927 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2928 )
2929 placeholders.append(
2930 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
2931 )
2932
2933 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01002934 elif op["op"] == Op.CONCAT:
2935 count = len(shapeList) - self.args.num_const_inputs_concat
2936 if count < 1:
2937 count = 1
2938 if self.args.num_const_inputs_concat == 0:
2939 count = len(shapeList)
2940
2941 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
2942 tens.extend(
2943 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
2944 )
2945 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002946 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002947 tens.extend(
2948 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2949 )
2950 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002951
Matthew Haddon1c00b712021-10-01 15:51:03 +01002952 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002953
2954 def createDynamicOpLists(self):
2955
2956 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07002957 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002958
Kevin Cheng1533b852021-09-01 12:51:58 -07002959 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002960 testName = "conv2d_{}x{}".format(k[0], k[1])
2961 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2962 self.TOSA_OP_LIST[testName]["filter"] = k
2963 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002964
Kevin Cheng550ccc52021-03-03 11:21:43 -08002965 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2966 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2967 "depthwise_conv2d_TEMPLATE"
2968 ].copy()
2969 self.TOSA_OP_LIST[testName]["filter"] = k
2970 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002971
Kevin Cheng550ccc52021-03-03 11:21:43 -08002972 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2973 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2974 "transpose_conv2d_TEMPLATE"
2975 ].copy()
2976 self.TOSA_OP_LIST[testName]["filter"] = k
2977 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002978
Kevin Cheng1533b852021-09-01 12:51:58 -07002979 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2980 for k in KERNELS_3D:
2981 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2982 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2983 self.TOSA_OP_LIST[testName]["filter"] = k
2984 self.TOSA_OP_LIST[testName]["template"] = False
2985
Eric Kunzee5e26762020-10-13 16:11:07 -07002986 # Delete any templates after having created any dynamic ops
2987 # This is a two-pass operation because it's bad practice to delete
2988 # keys from dictionaries while iterating
2989 keyList = []
2990 for k in self.TOSA_OP_LIST:
2991 try:
                if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002993 keyList.append(k)
2994 continue
2995 except KeyError:
2996 pass
2997
2998 for k in keyList:
2999 del self.TOSA_OP_LIST[k]
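
        # After expansion the op list contains concrete entries such as
        # "conv2d_3x3" (from KERNELS_2D) and "conv3d_1x1x1" (from KERNELS_3D),
        # and the *_TEMPLATE entries have been removed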
3000
3001 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003002 """Fill in default fields for ops if they aren't already specified.
3003 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07003004 for op in self.TOSA_OP_LIST:
3005
3006 # Required fields
3007 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003008 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003009 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003010 raise Exception(
3011 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
3012 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003013
3014 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003015 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003016 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003017 raise Exception(
3018 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
3019 op
3020 )
3021 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003022
3023 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003024 types = self.TOSA_OP_LIST[op]["types"]
            except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003026 raise Exception(
3027 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
3028 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003029
3030 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003031 opcode = self.TOSA_OP_LIST[op]["op"]
            except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003033 raise Exception(
3034 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
3035 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003036
3037 # Put in default rank range, if missing
3038 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003039 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003040 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003041 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003042
3043 # Tensor operator list
3044 # 'op': op name
3045 # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to a tuple inclusive of (min, max);
    #         if not specified, defaults to DEFAULT_RANK_RANGE
    # 'build_fcn': tuple of (build_operator function, TensorGen function, ArgGen function)
3049 # 'types': array of datatypes to be tested
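    # Optional fields used by some entries below:
    # 'qgen': quantization info generator (e.g. TosaQuantGen.qgUnary)
    # 'filter': kernel size, filled in by createDynamicOpLists
    # 'template': marks *_TEMPLATE entries expanded by createDynamicOpLists
    # 'invalid_test_validators' / 'error_if_validators': hooks for filtering
    #     invalid positive tests and for generating negative (ERROR_IF) tests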
Kevin Cheng550ccc52021-03-03 11:21:43 -08003050 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07003051
Kevin Cheng550ccc52021-03-03 11:21:43 -08003052 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
3053 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07003054
Kevin Cheng550ccc52021-03-03 11:21:43 -08003055 TYPE_BOOL = [DType.BOOL]
3056 TYPE_FI32 = [DType.FLOAT, DType.INT32]
3057 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
3058 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003059
Kevin Cheng550ccc52021-03-03 11:21:43 -08003060 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07003061
Kevin Cheng1533b852021-09-01 12:51:58 -07003062 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003063 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003064 [DType.INT8, DType.INT8, DType.INT32],
3065 [DType.INT16, DType.INT8, DType.INT48],
3066 DType.FLOAT,
3067 ]
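    # Each TYPE_CONV entry is an [input, weight, accumulator] dtype triple;
    # a bare dtype (here DType.FLOAT) is applied to all three operands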
3068
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003069 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003070
3071 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003072 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003073 "argmax": {
3074 "op": Op.ARGMAX,
3075 "operands": (1, 0),
3076 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3077 "types": TYPE_NARROW_INT_FP,
3078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 "avg_pool2d": {
3080 "op": Op.AVG_POOL2D,
3081 "operands": (1, 0),
3082 "rank": (4, 4),
3083 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3084 "qgen": TosaQuantGen.qgUnary,
3085 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003086 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003087 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003088 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003089 "conv2d_TEMPLATE": {
3090 "op": Op.CONV2D,
3091 "operands": (1, 2),
3092 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01003093 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003094 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003095 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003096 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003097 "template": True,
3098 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003099 # Templated operator. Filled in by createDynamicOpLists
3100 "conv3d_TEMPLATE": {
3101 "op": Op.CONV3D,
3102 "operands": (1, 2),
3103 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01003104 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07003105 "qgen": TosaQuantGen.qgConv,
3106 "types": TYPE_CONV,
3107 "template": True,
3108 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003109 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003110 "depthwise_conv2d_TEMPLATE": {
3111 "op": Op.DEPTHWISE_CONV2D,
3112 "operands": (1, 2),
3113 "filter": [1, 1],
3114 "rank": (4, 4),
3115 "build_fcn": (
3116 build_depthwise_conv2d,
3117 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01003118 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003119 ),
3120 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003121 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003122 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003123 "template": True,
3124 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003125 "fully_connected": {
3126 "op": Op.FULLY_CONNECTED,
3127 "operands": (1, 2),
3128 "rank": (2, 2),
3129 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
3130 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003131 "types": TYPE_CONV,
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003133 "matmul": {
3134 "op": Op.MATMUL,
3135 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003136 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08003137 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
3138 "qgen": TosaQuantGen.qgMatmul,
3139 "types": TYPE_NARROW_INT_FP,
3140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 "max_pool2d": {
3142 "op": Op.MAX_POOL2D,
3143 "operands": (1, 0),
3144 "rank": (4, 4),
3145 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3146 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003147 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003148 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003149 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003150 "transpose_conv2d_TEMPLATE": {
3151 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003152 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003153 "rank": (4, 4),
3154 "build_fcn": (
3155 build_transpose_conv2d,
3156 TosaTensorGen.tgTransposeConv2D,
3157 TosaArgGen.agTransposeConv2D,
3158 ),
3159 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003160 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003161 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003162 "template": True,
3163 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003164 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003165 "clamp": {
3166 "op": Op.CLAMP,
3167 "operands": (1, 0),
3168 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
3169 "types": TYPE_NARROW_INT_FP,
3170 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003171 "sigmoid": {
3172 "op": Op.SIGMOID,
3173 "operands": (1, 0),
3174 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
3175 "types": TYPE_FP,
3176 },
3177 "tanh": {
3178 "op": Op.TANH,
3179 "operands": (1, 0),
3180 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
3181 "types": TYPE_FP,
3182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 # Elementwise Binary Operators
3184 "add": {
3185 "op": Op.ADD,
3186 "operands": (2, 0),
3187 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3188 "types": TYPE_FI32,
3189 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003190 "arithmetic_right_shift": {
3191 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3192 "operands": (2, 0),
3193 "build_fcn": (
3194 build_arithmetic_right_shift,
3195 TosaTensorGen.tgBroadcastFuzz,
3196 TosaArgGen.agArithmeticRightShift,
3197 ),
3198 "types": TYPE_INT,
3199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003200 "bitwise_and": {
3201 "op": Op.BITWISE_AND,
3202 "operands": (2, 0),
3203 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3204 "types": TYPE_INT,
3205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003206 "bitwise_or": {
3207 "op": Op.BITWISE_OR,
3208 "operands": (2, 0),
3209 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3210 "types": TYPE_INT,
3211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 "bitwise_xor": {
3213 "op": Op.BITWISE_XOR,
3214 "operands": (2, 0),
3215 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3216 "types": TYPE_INT,
3217 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003218 "intdiv": {
3219 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003220 "operands": (2, 0),
3221 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3222 "types": [DType.INT32],
3223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 "logical_and": {
3225 "op": Op.LOGICAL_AND,
3226 "operands": (2, 0),
3227 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3228 "types": TYPE_BOOL,
3229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003230 "logical_left_shift": {
3231 "op": Op.LOGICAL_LEFT_SHIFT,
3232 "operands": (2, 0),
3233 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3234 "types": TYPE_INT,
3235 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003236 "logical_right_shift": {
3237 "op": Op.LOGICAL_RIGHT_SHIFT,
3238 "operands": (2, 0),
3239 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3240 "types": TYPE_INT,
3241 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003242 "logical_or": {
3243 "op": Op.LOGICAL_OR,
3244 "operands": (2, 0),
3245 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3246 "types": TYPE_BOOL,
3247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 "logical_xor": {
3249 "op": Op.LOGICAL_XOR,
3250 "operands": (2, 0),
3251 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3252 "types": TYPE_BOOL,
3253 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003254 "maximum": {
3255 "op": Op.MAXIMUM,
3256 "operands": (2, 0),
3257 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3258 "types": TYPE_FI32,
3259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003260 "minimum": {
3261 "op": Op.MINIMUM,
3262 "operands": (2, 0),
3263 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3264 "types": TYPE_FI32,
3265 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003266 "mul": {
3267 "op": Op.MUL,
3268 "operands": (2, 0),
3269 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
3270 "types": TYPE_INT_FP,
3271 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003272 "pow": {
3273 "op": Op.POW,
3274 "operands": (2, 0),
3275 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
3276 "types": TYPE_FP,
3277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003278 "sub": {
3279 "op": Op.SUB,
3280 "operands": (2, 0),
3281 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3282 "types": TYPE_FI32,
3283 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 "table": {
3285 "op": Op.TABLE,
3286 # Use the automatic generation functions to create the input array
3287 # but create the table tensor in the build function, as it may be
3288 # a different type from the input
3289 "operands": (1, 0),
3290 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003291 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003293 # Elementwise Unary operators
3294 "abs": {
3295 "op": Op.ABS,
3296 "operands": (1, 0),
3297 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3298 "types": TYPE_FI32,
3299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 "bitwise_not": {
3301 "op": Op.BITWISE_NOT,
3302 "operands": (1, 0),
3303 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3304 "types": TYPE_INT,
3305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003306 "ceil": {
3307 "op": Op.CEIL,
3308 "operands": (1, 0),
3309 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3310 "types": TYPE_FP,
3311 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 "clz": {
3313 "op": Op.CLZ,
3314 "operands": (1, 0),
3315 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3316 "types": [DType.INT32],
3317 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "exp": {
3319 "op": Op.EXP,
3320 "operands": (1, 0),
3321 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3322 "types": TYPE_FP,
3323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003324 "floor": {
3325 "op": Op.FLOOR,
3326 "operands": (1, 0),
3327 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3328 "types": TYPE_FP,
3329 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 "log": {
3331 "op": Op.LOG,
3332 "operands": (1, 0),
3333 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3334 "types": TYPE_FP,
3335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 "logical_not": {
3337 "op": Op.LOGICAL_NOT,
3338 "operands": (1, 0),
3339 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3340 "types": TYPE_BOOL,
3341 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003342 "negate": {
3343 "op": Op.NEGATE,
3344 "operands": (1, 0),
3345 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3346 "qgen": TosaQuantGen.qgUnary,
3347 "types": TYPE_INT_FP,
3348 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003349 "reciprocal": {
3350 "op": Op.RECIPROCAL,
3351 "operands": (1, 0),
3352 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3353 "types": TYPE_FP,
3354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "rsqrt": {
3356 "op": Op.RSQRT,
3357 "operands": (1, 0),
3358 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3359 "types": TYPE_FP,
3360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003361 # Elementwise Ternary operators
3362 "select": {
3363 "op": Op.SELECT,
3364 "operands": (3, 0),
3365 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
3366 "types": TYPE_FIB,
3367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003368 # Comparison operators
3369 "equal": {
3370 "op": Op.EQUAL,
3371 "operands": (2, 0),
3372 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3373 "types": TYPE_FI32,
3374 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003375 "greater_equal": {
3376 "op": Op.GREATER_EQUAL,
3377 "operands": (2, 0),
3378 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3379 "types": TYPE_FI32,
3380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 "greater": {
3382 "op": Op.GREATER,
3383 "operands": (2, 0),
3384 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3385 "types": TYPE_FI32,
3386 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003387 # Reduction operators
3388 "reduce_all": {
3389 "op": Op.REDUCE_ALL,
3390 "operands": (1, 0),
3391 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3392 "types": TYPE_BOOL,
3393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 "reduce_any": {
3395 "op": Op.REDUCE_ANY,
3396 "operands": (1, 0),
3397 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3398 "types": TYPE_BOOL,
3399 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003400 "reduce_max": {
3401 "op": Op.REDUCE_MAX,
3402 "operands": (1, 0),
3403 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3404 "types": TYPE_INT_FP,
3405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "reduce_min": {
3407 "op": Op.REDUCE_MAX,
3408 "operands": (1, 0),
3409 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3410 "types": TYPE_INT_FP,
3411 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003412 "reduce_product": {
3413 "op": Op.REDUCE_PRODUCT,
3414 "operands": (1, 0),
3415 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3416 "types": TYPE_FP,
3417 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 "reduce_sum": {
3419 "op": Op.REDUCE_SUM,
3420 "operands": (1, 0),
3421 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3422 "types": TYPE_FI32,
3423 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003424 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003425 "concat": {
3426 "op": Op.CONCAT,
3427 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01003428 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003429 "types": TYPE_FIB,
3430 },
3431 "pad": {
3432 "op": Op.PAD,
3433 "operands": (1, 0),
3434 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
3435 "qgen": TosaQuantGen.qgPad,
3436 "types": TYPE_FIB,
3437 },
3438 "reshape": {
3439 "op": Op.RESHAPE,
3440 "operands": (1, 0),
3441 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
3442 "types": TYPE_FIB,
3443 },
3444 "reverse": {
3445 "op": Op.REVERSE,
3446 "operands": (1, 0),
3447 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3448 "types": TYPE_FIB,
3449 },
3450 "slice": {
3451 "op": Op.SLICE,
3452 "operands": (1, 0),
3453 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
3454 "types": TYPE_FIB,
3455 },
3456 "tile": {
3457 "op": Op.TILE,
3458 "operands": (1, 0),
3459 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
3460 "types": TYPE_FIB,
3461 },
3462 "transpose": {
3463 "op": Op.TRANSPOSE,
3464 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003465 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003466 "build_fcn": (
3467 build_transpose,
3468 TosaTensorGen.tgBasic,
3469 TosaArgGen.agTranspose,
3470 ),
3471 "types": TYPE_FIB,
3472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 # Data nodes
3474 "const": {
3475 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003476 "operands": (0, 1),
3477 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 "types": TYPE_FIB,
3479 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003480 "identity": {
3481 "op": Op.IDENTITY,
3482 "operands": (1, 0),
3483 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3484 "types": TYPE_FIB,
3485 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003486 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003487 "gather": {
3488 "op": Op.GATHER,
            # Only specify the 'values' tensor here. 'indices' is generated in the op building stage.
3490 "operands": (1, 0),
3491 "rank": (3, 3),
3492 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
3493 "types": TYPE_INT_FP,
3494 },
3495 "scatter": {
3496 "op": Op.SCATTER,
            # Only specify the 'values_in' tensor here.
            # 'indices' and 'input' are generated in the op building stage.
3499 "operands": (2, 0),
3500 "rank": (3, 3),
3501 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
3502 "types": TYPE_INT_FP,
3503 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003504 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003505 "resize": {
3506 "op": Op.RESIZE,
3507 "operands": (1, 0),
3508 "rank": (4, 4),
3509 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
3510 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01003511 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
3512 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
3513 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01003514 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003515 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
3516 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003517 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003518 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003519 "cast": {
3520 "op": Op.CAST,
3521 "operands": (1, 0),
3522 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
3523 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
3524 },
3525 "rescale": {
3526 "op": Op.RESCALE,
3527 "operands": (1, 0),
3528 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003529 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08003530 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003531 # Custom
3532 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003533 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
        # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003537 "cond_if_const": {
3538 "op": Op.COND_IF,
3539 "operands": (0, 2),
3540 "build_fcn": (
3541 build_cond_if_const,
3542 TosaTensorGen.tgBasic,
3543 TosaArgGen.agCondIf,
3544 ),
3545 "types": [DType.BOOL],
3546 },
3547 "cond_if_binary": {
3548 "op": Op.COND_IF,
3549 "operands": (2, 0),
3550 "build_fcn": (
3551 build_cond_if_binary,
3552 TosaTensorGen.tgBasic,
3553 TosaArgGen.agCondIf,
3554 ),
3555 "types": TYPE_FI32,
3556 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003557 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003558 "while_loop": {
3559 "op": Op.WHILE_LOOP,
3560 "operands": (0, 1),
3561 "build_fcn": (
3562 build_while_loop,
3563 TosaTensorGen.tgBasic,
3564 TosaArgGen.agWhileLoop,
3565 ),
3566 "types": [DType.INT32],
3567 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003568 }
3569
Kevin Cheng550ccc52021-03-03 11:21:43 -08003570
Eric Kunzee5e26762020-10-13 16:11:07 -07003571class OutputShaper:
3572 # Methods in this class compute the expected output shape and datatype
3573 # for common classes of operations
3574 def __init__(self):
3575 pass
3576
3577 # These methods return arguments that can be used for
3578 # creating a new output tensor
3579 @staticmethod
3580 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003581 assert len(a.shape) == len(b.shape)
3582 assert a.dtype == b.dtype
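        # Broadcast sketch (illustrative shapes): a.shape = [1, 3, 1] with
        # b.shape = [2, 3, 4] -> output shape [2, 3, 4]; only size-1
        # dimensions of 'a' are widened to match 'b'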
Eric Kunzee5e26762020-10-13 16:11:07 -07003583
3584 shape = []
3585 for i in range(len(a.shape)):
3586 if a.shape[i] == 1:
3587 shape.append(b.shape[i])
3588 else:
3589 shape.append(a.shape[i])
3590
Kevin Cheng550ccc52021-03-03 11:21:43 -08003591 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003592
3593 @staticmethod
3594 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 assert len(a.shape) == len(b.shape)
3596 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003597
3598 shape = []
3599 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003600 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003601 shape.append(a.shape[i])
3602
Kevin Cheng550ccc52021-03-03 11:21:43 -08003603 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003604
3605 @staticmethod
3606 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003607 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003608
3609 @staticmethod
3610 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003611 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
3612 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003613
3614 shape = []
3615 for i in range(len(a.shape)):
3616 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3617
Kevin Cheng550ccc52021-03-03 11:21:43 -08003618 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003619
3620 @staticmethod
3621 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003622 assert len(a.shape) == len(b.shape)
3623 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003624
3625 # Do broadcast
3626 shape = []
3627 for i in range(len(a.shape)):
3628 if a.shape[i] == 1:
3629 shape.append(b.shape[i])
3630 else:
3631 shape.append(a.shape[i])
3632
3633 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07003635
3636 @staticmethod
3637 def reduceOp(ser, a, axis):
3638
3639 shape = a.shape.copy()
3640
3641 shape[axis] = 1
3642
Kevin Cheng550ccc52021-03-03 11:21:43 -08003643 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003644
3645 @staticmethod
3646 def argmaxOp(ser, a, axis):
3647 shape = a.shape.copy()
3648 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003649 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07003650
3651 @staticmethod
3652 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
3653
3654 # IFM: NHWC
3655 # Filter: OHWI
3656 # OFM: NHWC
3657
3658 if len(padding) == 2:
3659 # Expand padding to 4 parameters in the case of transpose_conv2d
3660 # From H,W to T,B,L,R
3661 padding = [padding[0], padding[0], padding[1], padding[1]]
3662
Kevin Cheng550ccc52021-03-03 11:21:43 -08003663 h = (
3664 ifm.shape[1]
3665 - filter.shape[1]
3666 - (filter.shape[1] - 1) * (dilations[0] - 1)
3667 + padding[0]
3668 + padding[1]
3669 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003670
Kevin Cheng550ccc52021-03-03 11:21:43 -08003671 w = (
3672 ifm.shape[2]
3673 - filter.shape[2]
3674 - (filter.shape[2] - 1) * (dilations[1] - 1)
3675 + padding[2]
3676 + padding[3]
3677 ) // strides[1] + 1
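
        # Worked example (illustrative values): H = 16, kernel = 3,
        # dilation = 2, padding = (1, 1), stride = 2
        # -> h = (16 - 3 - (3 - 1) * (2 - 1) + 1 + 1) // 2 + 1 = 7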
Eric Kunzee5e26762020-10-13 16:11:07 -07003678
Eric Kunzee5e26762020-10-13 16:11:07 -07003679 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
3680
Kevin Cheng3a478572021-01-22 17:21:02 -08003681 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003682 out_dtype = DType.INT32
3683 elif ifm.dtype == DType.INT16:
3684 out_dtype = DType.INT48
3685 elif ifm.dtype == DType.FLOAT:
3686 out_dtype = DType.FLOAT
3687 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003688 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003689
Kevin Cheng550ccc52021-03-03 11:21:43 -08003690 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003691
3692 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07003693 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
3694
3695 # IFM: NDHWC
3696 # Filter: ODHWI
3697 # OFM: NDHWC
3698
3699 d = (
3700 ifm.shape[1]
3701 - filter.shape[1]
3702 - (filter.shape[1] - 1) * (dilations[0] - 1)
3703 + padding[0]
3704 + padding[1]
3705 ) // strides[0] + 1
3706
3707 h = (
3708 ifm.shape[2]
3709 - filter.shape[2]
3710 - (filter.shape[2] - 1) * (dilations[1] - 1)
3711 + padding[2]
3712 + padding[3]
3713 ) // strides[1] + 1
3714
3715 w = (
3716 ifm.shape[3]
3717 - filter.shape[3]
3718 - (filter.shape[3] - 1) * (dilations[2] - 1)
3719 + padding[4]
3720 + padding[5]
3721 ) // strides[2] + 1
3722
3723 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
3724
3725 if ifm.dtype == DType.INT8:
3726 out_dtype = DType.INT32
3727 elif ifm.dtype == DType.INT16:
3728 out_dtype = DType.INT48
3729 elif ifm.dtype == DType.FLOAT:
3730 out_dtype = DType.FLOAT
3731 else:
3732 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
3733
3734 return ser.addOutput(ofm_shape, out_dtype)
3735
3736 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07003737 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
3738 # IFM: NHWC
3739 # Filter: HWCM
3740 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08003741 h = (
3742 ifm.shape[1]
3743 - filter.shape[0]
3744 - (filter.shape[0] - 1) * (dilations[0] - 1)
3745 + padding[0]
3746 + padding[1]
3747 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003748
Kevin Cheng550ccc52021-03-03 11:21:43 -08003749 w = (
3750 ifm.shape[2]
3751 - filter.shape[1]
3752 - (filter.shape[1] - 1) * (dilations[1] - 1)
3753 + padding[2]
3754 + padding[3]
3755 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003756
Eric Kunzee5e26762020-10-13 16:11:07 -07003757 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
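        # e.g. (illustrative) C = 8 input channels with depth multiplier M = 2
        # -> filter.shape[2] * filter.shape[3] = 8 * 2 = 16 output channels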
3758
Kevin Cheng3a478572021-01-22 17:21:02 -08003759 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003760 out_dtype = DType.INT32
3761 elif ifm.dtype == DType.INT16:
3762 out_dtype = DType.INT48
3763 elif ifm.dtype == DType.FLOAT:
3764 out_dtype = DType.FLOAT
3765 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003766 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003767
Kevin Cheng550ccc52021-03-03 11:21:43 -08003768 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003769
3770 @staticmethod
3771 def pool2dOp(ser, ifm, kernel, stride, pad):
3772 # input: NHWC
3773 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
3774 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
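        # Worked example (illustrative values): H = 32, pad = (0, 0),
        # stride = 2, kernel = 2 -> h = (32 + 0 + 0 + 2 - 2) // 2 = 16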
3775
Eric Kunzee5e26762020-10-13 16:11:07 -07003776 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003777 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003778
3779 @staticmethod
3780 def fullyConnectedOp(ser, input, filter):
3781 # input: N, IC
3782 # filter: OC, IC
3783 # output: N, OC
3784
3785 output_shape = [input.shape[0], filter.shape[0]]
3786
Kevin Cheng3a478572021-01-22 17:21:02 -08003787 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003788 out_dtype = DType.INT32
3789 elif input.dtype == DType.INT16:
3790 out_dtype = DType.INT48
3791 elif input.dtype == DType.FLOAT:
3792 out_dtype = DType.FLOAT
3793 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003794 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003795
Kevin Cheng550ccc52021-03-03 11:21:43 -08003796 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003797
3798 @staticmethod
3799 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07003800 # a: N, H, C
3801 # b: N, C, W
3802 # out: N, H, W
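        # Shape sketch (illustrative): a [8, 5, 16] x b [8, 16, 7]
        # -> output [8, 5, 7]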
Eric Kunzee5e26762020-10-13 16:11:07 -07003803
Kevin Cheng2d60f002021-06-09 14:18:32 -07003804 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003805
Kevin Cheng3a478572021-01-22 17:21:02 -08003806 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003807 out_dtype = DType.INT32
3808 elif a.dtype == DType.INT16:
3809 out_dtype = DType.INT48
3810 elif a.dtype == DType.FLOAT:
3811 out_dtype = DType.FLOAT
3812 else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003814
Kevin Cheng550ccc52021-03-03 11:21:43 -08003815 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003816
3817 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01003818 def concatOp(ser, axis, *a):
3819 input1 = a[0]
3820 remaining_inputs = a[1:]
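        # Shape sketch (illustrative): concatenating [2, 3, 4] and [2, 5, 4]
        # along axis 1 accumulates that axis -> output [2, 8, 4]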
Eric Kunzee5e26762020-10-13 16:11:07 -07003821
Matthew Haddon818ab902021-07-27 09:12:49 +01003822 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07003823
Matthew Haddon818ab902021-07-27 09:12:49 +01003824 output_shape[axis] = input1.shape[axis]
3825
3826 for tensor in remaining_inputs:
3827 output_shape[axis] += tensor.shape[axis]
3828
3829 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003830
3831 @staticmethod
3832 def padOp(ser, a, padding):
3833
3834 output_shape = a.shape.copy()
3835
3836 for i in range(len(output_shape)):
3837 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
3838
Kevin Cheng550ccc52021-03-03 11:21:43 -08003839 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003840
3841 @staticmethod
3842 def reshapeOp(ser, a, shape):
3843 output_shape = shape.copy()
3844
3845 totalElements = 1
3846 for i in a.shape:
3847 totalElements *= i
3848
3849 # If there are any -1 elements, figure out what that dimension must be
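        # Sketch (illustrative): an input of shape [2, 3, 4] (24 elements)
        # reshaped to [4, -1] resolves the -1 to 24 // 4 = 6, giving [4, 6]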
3850 totalOutputElements = 1
3851 for i in output_shape:
3852 if i != -1:
3853 totalOutputElements *= i
3854
3855 # And fill it in
3856 for i in range(len(output_shape)):
3857 if output_shape[i] == -1:
3858 output_shape[i] = totalElements // totalOutputElements
3859
Kevin Cheng550ccc52021-03-03 11:21:43 -08003860 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003861
3862 @staticmethod
3863 def sliceOp(ser, a, begin, size):
3864
3865 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003866 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003867
3868 @staticmethod
3869 def tileOp(ser, a, multiples):
3870
3871 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003872 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003873
3874 for i in range(len(output_shape)):
3875 output_shape[i] = a.shape[i] * multiples[i]
3876
Kevin Cheng550ccc52021-03-03 11:21:43 -08003877 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003878
3879 @staticmethod
3880 def transposeOp(ser, a, perms):
3881 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003882 assert len(perms) == len(output_shape)
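        # Sketch (illustrative): shape [4, 5, 6] with perms = [2, 0, 1]
        # -> output [6, 4, 5]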
Eric Kunzee5e26762020-10-13 16:11:07 -07003883
3884 for i in range(len(output_shape)):
3885 output_shape[i] = a.shape[perms[i]]
3886
Kevin Cheng550ccc52021-03-03 11:21:43 -08003887 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003888
3889 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08003890 def gatherOp(ser, values, indices):
3891 assert len(values.shape) == 3
3892 assert len(indices.shape) == 2
3893 assert values.shape[0] == indices.shape[0]
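        # Shape sketch (illustrative): values [2, 10, 4] with indices [2, 3]
        # -> output [2, 3, 4]; each index selects along the middle dimension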
Eric Kunzee5e26762020-10-13 16:11:07 -07003894
Kevin Cheng77d0f762020-11-24 10:26:32 -08003895 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
3896
Kevin Cheng550ccc52021-03-03 11:21:43 -08003897 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003898
3899 @staticmethod
3900 def scatterOp(ser, values_in, indices, input):
3901 assert len(values_in.shape) == 3
3902 assert len(indices.shape) == 2
3903 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08003904 assert values_in.shape[0] == indices.shape[0] # N
3905 assert input.shape[1] == indices.shape[1] # W
3906 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08003907
3908 output_shape = values_in.shape
3909
Kevin Cheng550ccc52021-03-03 11:21:43 -08003910 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003911
3912 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003913 def tableOp(ser, input, table_dtype):
3914 # Same shape as the input, but dtype dependent on table dtype
3915 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
3916 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
3917 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003918
3919 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08003920 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003921 serializer,
3922 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003923 input,
3924 mode,
3925 stride,
3926 offset,
3927 shift,
3928 stride_fp,
3929 offset_fp,
3930 output_dims,
3931 input_dtype,
3932 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003933 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08003934 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01003935 if error_name == ErrorIf.WrongRank:
3936 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
3937 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003938 if error_name == ErrorIf.BatchMismatch:
3939 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
3940 elif error_name == ErrorIf.ChannelMismatch:
3941 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
3942 else:
3943 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003944
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003945 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003946
3947 @staticmethod
3948 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003949 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003950
3951 @staticmethod
3952 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08003953 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003954 out_dtype = DType.INT32
3955 elif ifm.dtype == DType.INT16:
3956 out_dtype = DType.INT48
3957 elif ifm.dtype == DType.FLOAT:
3958 out_dtype = DType.FLOAT
3959 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003960 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003961
Kevin Cheng550ccc52021-03-03 11:21:43 -08003962 return ser.addOutput(output_shape, out_dtype)