blob: 2c131724a2bb48ba66f89a28a4e902db7d201e50 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
48DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080049Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070050ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.  Each qg* helper builds a
    serializer quant-info object with random zero points appropriate for the
    operator's dtype(s).
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        """Return a random zero point valid for dtype (0 for non-8-bit types)."""
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        """Quant info for unary ops: random input and output zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        """Quant info for convolution-style ops: random input and weights zero points.

        dtype_or_dtypeList is either a single dtype or a list of
        [input, weights, accumulator] dtypes; only the input and weights
        entries are used for the zero points.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        """Quant info for MATMUL: random zero points for both inputs."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        """Quant info for PAD: a single random input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  multiplier is a
        31-bit (scale32=True) or 15-bit (scale32=False) fixed-point mantissa
        such that scaleFp ~= multiplier * 2**-shift, with shift adjusted into
        the range [2, 62].
        """
        scaleBits = 31 if scale32 else 15

        m, shift = math.frexp(scaleFp)

        # frexp returns a negative mantissa for negative scales; use its magnitude.
        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding a mantissa just under 1.0 can hit exactly 1 << scaleBits;
        # renormalize to keep the multiplier in range.
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
143
144
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        # Stateless: all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """All operands (placeholder and const) share one random shape of the given rank."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Random rank-4 (NHWC per callers' comments) shape shared by all operands,
        with the batch dimension capped by args.max_batch_size."""
        pl, const = opName["operands"]

        # Rank-4 is only relaxed when deliberately generating a WrongRank error test.
        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Two placeholder shapes for scatter/gather-style ops: a values_in shape
        plus an input shape whose middle (W) dimension is chosen independently."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Same random shape for every operand, except one randomly chosen operand
        has one randomly chosen dimension forced to 1 (broadcast candidate)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter (OC x IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        # Random output-channel count within the configured shape range
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a, b] shapes for MATMUL; b shares dims 0 and 2 of a, with a
        random final dimension."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """One shared random shape for the base operands plus 0-3 extra concat inputs.

        NOTE(review): the comment below says pl should bound the number of
        concats, but num_tensors is drawn independently of pl — confirm intent.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rewrite a concat shape list so the later inputs split the first shape
        along axis (halving the remainder each step), avoiding many identical
        large const tensors.  Returns shapeList unchanged when it is too short
        or the axis is too small to split."""
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            # The final iteration also emits the leftover remainder shape
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
429
430
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        # Stateless: all generators are static methods.
        pass
439
440 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100441 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800442 """A trivial argument generator for operators that don't take any
443 non-tensor arguments"""
444 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700445
446 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100447 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800448 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700449 axes = []
450
451 shape = shapeList[0]
452
453 for a in range(0, len(shape)):
Matthew Haddon43e37192021-07-09 14:13:02 +0100454 axes.append(("axis{}".format(a), [a]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return axes
456
    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, dilation) argument combinations for the
        convolution operators.

        Value sets are bounded by the testGen.args.max_conv_* limits (with a
        few oversize values mixed in for small inputs), then sampled sparsely.
        Combinations are kept only when the padded input exceeds both the
        kernel size and the dilation.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        # bump sparsity until it is coprime with 2, 3 and 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        # sorted() makes the sampled subset deterministic across runs
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list
526
527 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100528 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700529 arg_list = []
530
531 ifm_shape = shapeList[0]
532 filter_shape = shapeList[1]
533
534 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800535 assert len(ifm_shape) == 4
536 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700537
Les Bell7aa69f42021-09-20 10:44:07 +0100538 # Generate comprehensive argument lists
539 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
540 paddings = {x for x in itertools.product(*([p_vals] * 2))}
541 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
542 strides = {x for x in itertools.product(*([s_vals] * 2))}
543 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
544 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700545
Les Bell7aa69f42021-09-20 10:44:07 +0100546 # add some oversize argument values
547 if max(ifm_shape) < 64:
548 bigPadding = 9
549 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
550 bigStride = 8
551 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
552 bigDilation = 7
553 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700554
Les Bell7aa69f42021-09-20 10:44:07 +0100555 # There are too many parameter combinations, so generate them sparsely
556 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
557 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
558 if sparsity < 13:
559 sparsity = 1
560 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
561 sparsity += 1
562 n = 0
563 for s in sorted(list(strides)):
564 for p in sorted(list(paddings)):
565 for d in sorted(list(dilations)):
566 if n % sparsity == 0:
567 # Determine the output shape
568 oh = (
569 ifm_shape[1]
570 - filter_shape[1]
571 - (filter_shape[1] - 1) * (d[0] - 1)
572 + 2 * p[0]
573 ) // s[0] + 1
574 ow = (
575 ifm_shape[2]
576 - filter_shape[2]
577 - (filter_shape[2] - 1) * (d[1] - 1)
578 + 2 * p[1]
579 ) // s[1] + 1
580 os = [ifm_shape[0], oh, ow, filter_shape[0]]
581 arg_list.append(
582 (
583 "st{}_pad{}_dilat{}_os{}".format(
584 "".join([str(x) for x in s]),
585 "".join([str(x) for x in p]),
586 "".join([str(x) for x in d]),
587 "x".join([str(x) for x in os]),
588 ),
589 [s, p, d, os],
590 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800591 )
Les Bell7aa69f42021-09-20 10:44:07 +0100592 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700593
594 return arg_list
595
596 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100597 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700598 arg_list = []
599 rank = len(shapeList[0])
600
Les Bell7ffccce2021-07-28 15:37:02 +0100601 # Exhaustively test combinations of padding on each side of each dimension
602 # - the range of padding values is defined by pad_min and pad_max
603 # - for padding >9, the name format needs to be more distinctive
604 pad_min, pad_max = 0, 1
605 pad_values = [x for x in range(pad_min, pad_max + 1)]
606 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
607 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700608
Les Bell7ffccce2021-07-28 15:37:02 +0100609 for paddings in shape_pad_values:
610 name = "pad"
611 for r in range(rank):
612 before, after = paddings[r]
613 name = f"{name}{before}{after}"
614 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700615
616 return arg_list
617
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, kernel) argument combinations for the
        pooling operators.

        Value sets are bounded by the testGen.args.max_pooling_* limits, with
        a few oversize values mixed in, then sampled sparsely.  Combinations
        are kept only when the padding fits inside the kernel and the padded
        input exceeds the kernel size.
        """
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        # sorted() makes the sampled subset deterministic across runs
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
668
669 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100670 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700671 arg_list = []
672
673 # Enumerate the output types here
674 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800675 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700676 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800677 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700678 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800679 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700680 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800681 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700682 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800683 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800685 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700686
687 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800688 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700689
690 return arg_list
691
692 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100693 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700694 arg_list = []
695
696 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100697 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
698 if inDtype == DType.UINT8 and dtype != DType.INT8:
699 # The only output dtype for UINT8 is INT8, skip all other combinations
700 continue
701 if inDtype != DType.INT8 and dtype == DType.UINT8:
702 # The only input dtype for UINT8 is INT8, skip all other combinations
703 continue
704
Kevin Cheng550ccc52021-03-03 11:21:43 -0800705 for scale32 in [False, True]:
706 for double_round in [False, True]:
707 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700708
709 if inDtype == DType.INT48 and scale32:
710 # Illegal condition. Must be scale32=False
711 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100712 if double_round and not scale32:
713 # Illegal condition. ERROR_IF(!scale32 && double_round)
714 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700715
Kevin Cheng550ccc52021-03-03 11:21:43 -0800716 arg_list.append(
717 (
718 "out{}_sc{}_dr{}_pc{}".format(
719 DTypeNames[dtype],
720 int(scale32),
721 int(double_round),
722 int(per_channel),
723 ),
724 [dtype, scale32, double_round, per_channel],
725 )
726 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700727
728 return arg_list
729
Kevin Chengaee1fac2020-11-11 13:54:06 -0800730 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100731 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800732 arg_list = []
733
734 if dtype is DType.INT32:
735 for p in range(testGen.args.num_rand_permutations):
736
737 shift = testGen.randInt(0, 32)
738
Kevin Cheng550ccc52021-03-03 11:21:43 -0800739 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800740 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100741 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800742
743 return arg_list
744
745 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100746 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800747 arg_list = []
748
Kevin Cheng550ccc52021-03-03 11:21:43 -0800749 arg_list.append(("roundTrue", [True]))
750 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800751
752 return arg_list
753
Eric Kunzee5e26762020-10-13 16:11:07 -0700754 # Helper function for reshape. Gets some factors of a larger number.
755 @staticmethod
756 def getFactors(val, start=1):
757 factors = []
758
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100759 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700760 if (val % i) == 0:
761 factors.append(i)
762
763 return factors
764
765 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100766 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700767 arg_list = []
768
769 origShape = shapeList[0]
770
771 totalElements = 1
772 for s in origShape:
773 totalElements *= s
774
775 # This code is NOT fast. Fortunately, the numbers are fairly small.
776 factors = TosaArgGen.getFactors(totalElements)
777
778 for p in range(testGen.args.num_rand_permutations):
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100779 newRank = testGen.randInt(1, 7)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800780 if len(factors) < newRank:
Eric Kunzee5e26762020-10-13 16:11:07 -0700781 continue
782
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100783 found = True
784 # escape_counter breaks while loop if it continues on for too long
785 escape_counter = 0
786 while found:
787 newShape = []
788 # Generate newShape ensuring it isn't a duplicate
789 remainingElements = totalElements
790 shuffledFactors = testGen.rng.permutation(factors)
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100791 for i in range(1, newRank):
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100792 # pick rank-1 factors
793 newShape.append(shuffledFactors[0])
794 remainingElements = remainingElements // shuffledFactors[0]
795 shuffledFactors = testGen.rng.permutation(
796 TosaArgGen.getFactors(remainingElements)
797 )
798 newShape.append(remainingElements)
Eric Kunzee5e26762020-10-13 16:11:07 -0700799
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100800 # Toss in a -1 sometimes
801 minusOne = testGen.randInt(0, newRank * 4)
802 if minusOne < newRank:
803 newShape[minusOne] = -1
Eric Kunzee5e26762020-10-13 16:11:07 -0700804
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100805 # Check for duplicates
806 found = False
807 for name, other_shape in arg_list:
808 if other_shape[0] == newShape:
809 found = True
810 break
811
812 escape_counter += 1
813 if escape_counter >= 100:
814 break
815
816 if not found:
817 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700818
819 return arg_list
820
Eric Kunzee5e26762020-10-13 16:11:07 -0700821 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100822 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700823 arg_list = []
824
825 ifm_shape = shapeList[0]
826
Jeremy Johnsona6185572021-06-21 15:55:35 +0100827 # Get all permutations
828 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700829
Jeremy Johnsona6185572021-06-21 15:55:35 +0100830 # Limit to possible permutations from shape dimension or argument setting
831 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700832
Jeremy Johnsona6185572021-06-21 15:55:35 +0100833 # Get random permutation generator that uses all permutations
834 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700835
Jeremy Johnsona6185572021-06-21 15:55:35 +0100836 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700837 arg_list = [
838 ("perm{}".format(p), [random_permutations[p].tolist()])
839 for p in range(limit)
840 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700841 return arg_list
842
843 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100844 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700845 arg_list = []
846
847 ifm_shape = shapeList[0]
848 rank = len(ifm_shape)
849
850 for p in range(testGen.args.num_rand_permutations):
851 begin = []
852 size = []
853
Kevin Cheng550ccc52021-03-03 11:21:43 -0800854 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700855
856 for i in range(rank):
857 if ifm_shape[i] > 1:
858 begin.append(testGen.randInt(0, ifm_shape[i]))
859 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
860
861 # Invalid slice size?
862 if size[i] == 0:
863 valid = False
864 else:
865 begin.append(0)
866 size.append(1)
867
868 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800869 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700870 return arg_list
871
872 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100873 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700874 arg_list = []
875
876 ifm_shape = shapeList[0]
877 rank = len(ifm_shape)
878
879 for p in range(testGen.args.num_rand_permutations):
880
881 # Pick a few random, but small multiple values
882 # because otherwise this has a tendency to generate
883 # enormous tensors
884 multiples = []
885 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +0100886 if ifm_shape[i] > 1000:
887 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
888 multiples.append(1)
889 elif max(ifm_shape) > 1000:
890 multiples.append(2)
891 else:
892 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -0800893 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700894
895 return arg_list
896
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Argument generator for RESIZE.

        For each legal (mode, output dtype) pairing of the input dtype,
        randomly generates output dimensions and derives the matching
        stride and offset values: floating-point for a FLOAT output,
        16.shift fixed-point integers otherwise.  When error_name is set,
        the arguments are corrupted by TosaErrorIfArgGen.eiResizeErrorIf
        so the test provokes that specific ERROR_IF.

        Each arg_list entry is (test_name_suffix,
        [mode, stride, offset, shift, stride_fp, offset_fp, output_dims,
        dtype, outputDType]).
        """
        arg_list = []

        # NOTE(review): indexing ifm_shape[1]/[2] below assumes NHWC rank-4
        # input — confirm against the shape generator.
        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types.
            # NEAREST keeps the dtype; BILINEAR widens (INT8->INT32,
            # INT16->INT48); FLOAT stays FLOAT in either mode.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Grow the output until the downscale ratio per axis is
                    # below 16, keeping the stride inside its legal range.
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Align input and output centers to derive stride/offset.
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT path: shift and the integer stride/offset
                        # must be zero; only the _fp values carry data.
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            # Corrupt one argument to trigger error_name.
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        # Integer path: express stride/offset in 16.shift
                        # fixed point, starting at shift=11.
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        # Reduce shift until every fixed-point value fits
                        # inside the legal (-16 << shift, 16 << shift) range.
                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        # The _fp values are unused for integer types.
                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            # Corrupt one argument to trigger error_name.
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list
1067
Matthew Haddon1c00b712021-10-01 15:51:03 +01001068 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001069 # CondIf generates the condition values here.
1070 # Convert to tensors in the build function, along with the
1071 # then and else blocks
1072 arg_list = []
1073
1074 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001075 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001076
1077 return arg_list
1078
Matthew Haddon1c00b712021-10-01 15:51:03 +01001079 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001080 # While loop: 0 iterations, 1, more than 1
1081 arg_list = []
1082
1083 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001084 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001085
1086 return arg_list
1087
Matthew Haddone86fd342021-09-07 16:12:21 +01001088class TosaErrorIfArgGen:
1089
1090 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001091 def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
Matthew Haddone86fd342021-09-07 16:12:21 +01001092
1093 if outputDType == DType.FLOAT:
1094 if error_name == ErrorIf.StrideSmallerEqualZero:
1095 stride_fp = testGen.rng.random(size=[2]) - 2
1096 elif error_name == ErrorIf.ShiftNotZero:
1097 shift = testGen.rng.integers(1, 5)
1098 elif error_name == ErrorIf.StrideLargerDimension:
1099 shape = shapeList[0]
1100 transform_height = testGen.rng.choice([False, True])
1101 if transform_height:
1102 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
1103 else:
1104 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
1105 else:
1106 if error_name == ErrorIf.StrideSmallerEqualZero:
1107 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
1108 elif error_name == ErrorIf.ShiftSmallerOne:
1109 shift = testGen.rng.integers(-3, 1)
1110 if shift <= 0:
1111 stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1112 offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1113 else:
1114 stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1115 offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1116 elif error_name == ErrorIf.ShiftLargerEleven:
1117 shift = np.int16(testGen.rng.integers(12, 15))
1118 elif error_name == ErrorIf.StrideLargerDimension:
1119 shape = shapeList[0]
1120 transform_height = testGen.rng.choice([False, True])
1121 if transform_height:
1122 stride[0] = shape[1] + testGen.rng.integers(1, 10)
1123 else:
1124 stride[1] = shape[2] + testGen.rng.integers(1, 10)
1125 elif error_name == ErrorIf.StrideLargerEqualMax:
1126 stride = [(16 << shift) + 1, (16 << shift) + 1]
1127 elif error_name == ErrorIf.OffsetLargerEqualMax:
1128 offset = [(16 << shift) + 1, (16 << shift) + 1]
1129 elif error_name == ErrorIf.OffsetSmallerEqualMin:
1130 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
1131
Matthew Haddon1c00b712021-10-01 15:51:03 +01001132
Matthew Haddon848efb42021-09-09 12:30:53 +01001133 if error_name == ErrorIf.WrongOutputType:
1134 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
1135 incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1136 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
1137 incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
1138 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
1139 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
1140 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
1141 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
1142 elif dtype == DType.FLOAT:
1143 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
1144 outputDType = testGen.rng.choice(a=incorrect_types)
Matthew Haddone86fd342021-09-07 16:12:21 +01001145
Matthew Haddon848efb42021-09-09 12:30:53 +01001146 return shift, stride, stride_fp, offset, offset_fp, outputDType
1147
1148 @staticmethod
1149 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
1150 # Mess up input/output tensors for ERROR_IF checks
1151 if error_name == "WrongInputList":
1152 add_input = testGen.rng.choice([True, False])
1153 if add_input:
1154 input_list.append('eiDummyInput')
1155 else:
1156 input_list = input_list[:-1]
1157 if error_name == "WrongOutputList":
1158 add_output = testGen.rng.choice([True, False])
1159 if add_output:
1160 output_list.append('eiDummyOutput')
1161 else:
1162 output_list = []
1163 return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001164
class TosaErrorValidator:
    """Library of ERROR_IF validators.

    Each ev* function describes one ERROR_IF condition.  Called with
    check=False it returns only the condition's metadata (name, reason
    and the parameter requirements needed to provoke it); with check=True
    it additionally evaluates the supplied kwargs and reports in
    'error_result' whether the condition fired.
    """

    @staticmethod
    def _makeInfoDict(error_name, error_result, error_reason, param_reqs):
        # Assemble the result dictionary shared by every ev* validator.
        return {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }

    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Run every validator against the generated test.

        Exactly the ERROR_IF named by error_name must fire: if it does,
        the serializer is told to expect the error return code.  If a
        different check fires, or the requested one does not, the test is
        unusable and None is returned so the caller can delete it.
        """
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)

            validator_name = val_result['error_name']
            error_result = val_result['error_result']
            error_reason = val_result['error_reason']

            if error_result:
                if error_name == validator_name:
                    serializer.setExpectedReturnCode(2, error_reason)
                else:
                    print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
                    return None  # Return None to delete test if wrong ERROR_IF is hit
            else:
                if error_name == validator_name:
                    print(f"No ERROR_IF hit for {error_name}")
                    return None

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        """ERROR_IF: the input dtype is not in the operator's supported list."""
        all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)

        # Find the unsupported input data types
        assert 'op' in kwargs
        op = kwargs['op']
        input_dtypes = op['types']
        wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))

        error_result = False
        if check:
            input_dtype = kwargs['input_dtype']
            if input_dtype not in input_dtypes:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.WrongInputType,
            error_result,
            "Input data type not supported for this operator",
            {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        )

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """ERROR_IF: output dtype does not match the legal mode/input pairing."""
        error_result = False
        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            mode = kwargs['mode']

            # Legal pairings: NEAREST keeps the dtype, BILINEAR widens
            # (INT8->INT32, INT16->INT48), FLOAT stays FLOAT.
            if (
                (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
            ):
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.WrongOutputType,
            error_result,
            "Output data type not supported for this configuration of operator",
            {"rank": None, "dtype": None, "shape": None},
        )

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        """ERROR_IF: input rank outside the operator's supported range."""
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert 'op' in kwargs
        op = kwargs['op']
        rmin, rmax = op['rank']
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Set minimum incorrect rank to 3 to avoid index error
        if op['op'] == Op.RESIZE:
            incorrect_ranks = [3, 5]

        error_result = False
        if check:
            input_shape = kwargs['input_shape']
            if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.WrongRank,
            error_result,
            "Rank not supported for this operator",
            {"rank": incorrect_ranks, "dtype": None, "shape": None},
        )

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        """ERROR_IF: the serialized input tensor list has the wrong length."""
        # (An unused 'op' kwarg lookup was removed here.)
        error_result = False
        if check:
            input_list = kwargs['input_list']
            num_operands = kwargs['num_operands']
            if len(input_list) != num_operands:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.WrongInputList,
            error_result,
            "Op input list does not match expected input",
            {"rank": None, "dtype": None, "shape": None},
        )

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        """ERROR_IF: the serialized output tensor list has the wrong length."""
        error_result = False
        if check:
            output_list = kwargs['output_list']
            # Note this will be incorrect if an operator returns more than one output
            if len(output_list) != 1:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.WrongOutputList,
            error_result,
            "Op output list does not match expected output",
            {"rank": None, "dtype": None, "shape": None},
        )

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        """ERROR_IF: an input or output H/W dimension exceeds 16384."""
        param_reqs = {
            "rank": [4,4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
        }
        error_result = False
        if check:
            input_shape = kwargs['input_shape'].shape
            output_shape = kwargs['output_shape']  # Note this is just (OH, OW)
            if ((input_shape[1] > 16384) or
                (input_shape[2] > 16384) or
                (output_shape[0] > 16384) or
                (output_shape[1] > 16384)):
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.MaxDimExceeded,
            error_result,
            "At least one maximum dimension is larger than 16384",
            param_reqs,
        )

    @staticmethod
    def evStrideSmallerEqualZero(check=False, **kwargs):
        """ERROR_IF: a resize stride value is <= 0."""
        error_result = False
        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            # Pick the integer or floating-point stride to match the dtype
            # pairing; the mismatched pairings exist to work around
            # wrong input/output type tests.
            if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
                stride = kwargs['stride']  # Work around wrong input/output type tests
            elif output_dtype == DType.FLOAT:
                stride = kwargs['stride_fp']
            elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                stride = kwargs['stride_fp']  # Work around wrong input/output type tests
            else:
                stride = kwargs['stride']

            if min(stride) <= 0:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.StrideSmallerEqualZero,
            error_result,
            "Stride value smaller than or equal zero",
            {"rank": None, "dtype": None, "shape": None},
        )

    @staticmethod
    def evStrideLargerEqualMax(check=False, **kwargs):
        """ERROR_IF: integer stride >= (16 << shift)."""
        error_result = False
        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            stride = kwargs['stride']
            if input_dtype in [DType.INT8, DType.INT16]:
                # A negative shift divides the limit instead of multiplying.
                if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
                    error_result = True
                elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
                    error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.StrideLargerEqualMax,
            error_result,
            "Stride value larger than or equal to maximum value",
            {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
        )

    @staticmethod
    def evStrideLargerDimension(check=False, **kwargs):
        """ERROR_IF: float stride larger than the input H/W dimension."""
        error_result = False
        if check:
            shape = kwargs['input_shape'].shape
            input_dtype = kwargs['input_dtype']
            stride = kwargs['stride_fp']

            # Fix: parenthesized the 'or' so the FLOAT guard applies to
            # both comparisons.  'and' binds tighter than 'or', so the
            # original condition fired on stride[1] for any input dtype.
            if input_dtype == DType.FLOAT and ((stride[0] > shape[1]) or (stride[1] > shape[2])):
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.StrideLargerDimension,
            error_result,
            "Stride value larger than or equal to H/W dimension",
            {"rank": None, "dtype": [DType.FLOAT], "shape": None},
        )

    @staticmethod
    def evOffsetSmallerEqualMin(check=False, **kwargs):
        """ERROR_IF: offset <= -(16 << shift) (>> for negative shift)."""
        error_result = False
        if check:
            shift = kwargs['shift']
            output_dtype = kwargs['output_dtype']
            if output_dtype == DType.FLOAT:
                offset = kwargs['offset_fp']
            else:
                offset = kwargs['offset']

            if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
                error_result = True
            elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.OffsetSmallerEqualMin,
            error_result,
            "Offset value smaller than or equal to minimum value",
            {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
        )

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        """ERROR_IF: offset >= (16 << shift) (>> for negative shift)."""
        error_result = False
        if check:
            shift = kwargs['shift']
            output_dtype = kwargs['output_dtype']
            if output_dtype == DType.FLOAT:
                offset = kwargs['offset_fp']
            else:
                offset = kwargs['offset']

            # Fix: removed a duplicated copy of the shift >= 0 comparison
            # that preceded this if/elif with no additional effect.
            if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
                error_result = True
            elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.OffsetLargerEqualMax,
            error_result,
            "Offset value larger than or equal to maximum value",
            {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
        )

    @staticmethod
    def evShiftNotZero(check=False, **kwargs):
        """ERROR_IF: shift must be zero when resizing FLOAT tensors."""
        error_result = False
        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.ShiftNotZero,
            error_result,
            "Shift value must be zero for float input",
            {"rank": None, "dtype": [DType.FLOAT], "shape": None},
        )

    @staticmethod
    def evShiftSmallerOne(check=False, **kwargs):
        """ERROR_IF: shift < 1 for integer resize."""
        error_result = False
        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.ShiftSmallerOne,
            error_result,
            "Shift value smaller than one",
            {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
        )

    @staticmethod
    def evShiftLargerEleven(check=False, **kwargs):
        """ERROR_IF: shift > 11."""
        error_result = False
        if check:
            shift = kwargs['shift']
            if shift > 11:
                error_result = True

        return TosaErrorValidator._makeInfoDict(
            ErrorIf.ShiftLargerEleven,
            error_result,
            "Shift value larger than eleven",
            {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
        )
1553
1554
class TosaInvalidValidator:
    """Predicates that flag invalid argument combinations for ops.

    Each iv* staticmethod returns True when the op/argument combination is
    not a legal TOSA graph and should be skipped by the generator.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the resize mode / input dtype / output dtype combo is invalid."""
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Fix: the original OR'd the *negated* valid pairings, which is
            # always True (input_dtype cannot equal three types at once).
            # Valid BILINEAR pairings: i8->i32, i16->i48, float->float.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype.
            # NOTE(review): DType.INT32 here looks like it may be intended to
            # be DType.INT16 — confirm against the TOSA resize dtype table.
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
            )
        else:
            # Unrecognized resize mode is always invalid.
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if any resize stride component is zero or negative."""
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        # Float resize uses the fp strides; integer resize the int strides.
        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if the computed output H or W of a conv/pool op is <= 0.

        Layout is taken from the index arithmetic below: dims [1] and [2] of
        the first input shape are treated as H and W.
        """
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        input = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            # Pooling ops pack the kernel where convs pack dilations.
            kernel = args[2]

        if opName.startswith('conv2d'):
            h = (
                input[1]
                - filter[1]
                - (filter[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[2]
                - (filter[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            # Depthwise filters put kernel H/W at dims [0]/[1].
            h = (
                input[1]
                - filter[0]
                - (filter[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[1]
                - (filter[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if the explicit output shape has non-positive H or W."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
1671
1672
Kevin Cheng550ccc52021-03-03 11:21:43 -08001673
class TosaTestGen:
    """Generates single-op TOSA test graphs and serializes them to disk."""

    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6
1677
    def __init__(self, args):
        """Initialise the generator from parsed command-line args (needs
        args.output_dir, args.random_seed, args.tensor_shape_range)."""
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None  # created per-test by createSerializer()
        self.rng = np.random.default_rng(self.random_seed)
        # Dynamic op lists are created before the defaults are installed.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
1689
1690 def createSerializer(self, opName, testPath):
1691 self.testPath = os.path.join(opName, testPath)
1692
1693 fullPath = os.path.join(self.basePath, self.testPath)
1694 os.makedirs(fullPath, exist_ok=True)
1695 self.ser = ts.TosaSerializer(fullPath)
1696
1697 def getSerializer(self):
1698 return self.ser
1699
1700 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001701 with open(
1702 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
1703 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07001704 fd.write(self.ser.serialize())
1705
Kevin Cheng550ccc52021-03-03 11:21:43 -08001706 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
1707 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07001708
Matthew Haddon74567092021-07-16 15:38:20 +01001709 def resetRNG(self, seed=None):
1710 if seed == None:
1711 seed = self.random_seed + 1
1712 self.rng = np.random.default_rng(seed)
1713
Eric Kunzee5e26762020-10-13 16:11:07 -07001714 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07001715 if dtype == DType.BOOL:
1716 np_dt = np.bool
1717 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07001718 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001719 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001720 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001721 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001722 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
1723 elif dtype == DType.UINT8:
1724 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001725 elif dtype == DType.INT16:
1726 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
1727 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001728 return np.int32(
1729 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
1730 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001731 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001732 return np.int64(
1733 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
1734 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001735 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001736 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001737 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001738 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001739
Kevin Cheng989cb052021-04-28 16:29:44 -07001740 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001741 placeholders = []
1742
Kevin Cheng989cb052021-04-28 16:29:44 -07001743 assert len(shape_list) == len(dtype_list)
1744
1745 for idx, shape in enumerate(shape_list):
1746 arr = self.getRandTensor(shape, dtype_list[idx])
1747 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001748
1749 return placeholders
1750
Kevin Cheng989cb052021-04-28 16:29:44 -07001751 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001752 consts = []
1753
Kevin Cheng989cb052021-04-28 16:29:44 -07001754 assert len(shape_list) == len(dtype_list)
1755
1756 for idx, shape in enumerate(shape_list):
1757 arr = self.getRandTensor(shape, dtype_list[idx])
1758 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001759
1760 return consts
1761
1762 def makeShape(self, rank):
1763 if self.targetted_shape:
1764 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001765 return np.int32(
1766 self.rng.integers(
1767 low=self.args.tensor_shape_range[0],
1768 high=self.args.tensor_shape_range[1],
1769 size=rank,
1770 )
1771 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001772
1773 def setTargetShape(self, shape):
1774 self.targetted_shape = shape
1775
1776 def randInt(self, low=0, high=256):
1777 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
1778
    def getRandNumberDType(self, dtype):
        """Return a single random scalar of the given TOSA DType.

        FLOAT and BOOL return directly; integer types fall through to the
        shared np.int32 draw at the bottom, except INT48 which returns an
        np.int64 early because the value may not fit 32 bits.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]
1801
1802 def shapeStr(self, shape):
1803
1804 sStr = []
1805 # Convert to strings
1806 for i in shape:
1807 sStr.append(str(i))
1808
Kevin Cheng550ccc52021-03-03 11:21:43 -08001809 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001810
1811 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001812 if isinstance(t, list):
1813 assert len(t) >= 2
1814 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001815 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001816 if t == DType.BOOL:
1817 return "b"
1818 elif t == DType.INT4:
1819 return "i4"
1820 elif t == DType.INT8:
1821 return "i8"
1822 elif t == DType.UINT8:
1823 return "u8"
1824 elif t == DType.INT16:
1825 return "i16"
1826 elif t == DType.INT32:
1827 return "i32"
1828 elif t == DType.INT48:
1829 return "i48"
1830 elif t == DType.FLOAT:
1831 return "float"
1832 else:
1833 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001834
1835 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001836 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001837 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001838 return 4
1839 elif t == DType.INT8:
1840 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001841 elif t == DType.UINT8:
1842 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001843 elif t == DType.INT16:
1844 return 16
1845 elif t == DType.INT32:
1846 return 32
1847 elif t == DType.INT48:
1848 return 48
1849 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001850 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001851
1852 # Argument generators
1853 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1854 # Where the string descriptor is used to generate the test name and
1855 # The build_fcn_arg_list is expanded and passed to the operator test
1856 # build function
1857
Kevin Cheng550ccc52021-03-03 11:21:43 -08001858 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001859 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01001860 # build_placeholder returns an int, ABS/other ops does not
1861 if isinstance(op, int):
1862 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1863 else:
1864 self.ser.addOperator(op['op'], [a.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001865 return result_tens
1866
1867 def build_binary_broadcast(self, op, a, b):
1868 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001869 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001870 return result_tens
1871
1872 def build_binary_nonbroadcast(self, op, a, b):
1873 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001874 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001875 return result_tens
1876
Kevin Chengaee1fac2020-11-11 13:54:06 -08001877 def build_arithmetic_right_shift(self, op, a, b, round):
1878 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1879
1880 attr = ts.TosaSerializerAttribute()
1881 attr.ArithmeticRightShiftAttribute(round)
1882
Matthew Haddon848efb42021-09-09 12:30:53 +01001883 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001884 return result_tens
1885
1886 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001887 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1888
1889 # Special for multiply:
1890 # Force the result to INT32 for INT types
1891 if a.dtype != DType.FLOAT:
1892 result_tens.setDtype(DType.INT32)
1893
Kevin Chengaee1fac2020-11-11 13:54:06 -08001894 attr = ts.TosaSerializerAttribute()
1895 attr.MulAttribute(shift)
1896
Matthew Haddon848efb42021-09-09 12:30:53 +01001897 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001898 return result_tens
1899
1900 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001901 # Constant size depending on type, random values
1902 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07001903 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001904 table_arr = self.getRandTensor([513], table_dtype)
1905 else:
1906 assert a.dtype == DType.INT8
1907 table_dtype = DType.INT8
1908 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001909
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001910 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1911 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01001912 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001913
1914 return result_tens
1915
1916 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001917 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001918 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001919 return result_tens
1920
1921 def build_comparison(self, op, a, b):
1922 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001923 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001924 return result_tens
1925
1926 def build_argmax(self, op, a, axis):
1927 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1928
1929 attr = ts.TosaSerializerAttribute()
1930 attr.AxisAttribute(axis)
1931
Matthew Haddon848efb42021-09-09 12:30:53 +01001932 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001933 return result_tens
1934
Matthew Haddonb724efc2021-08-25 16:40:29 +01001935 def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001936 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1937
1938 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001939 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001940
Matthew Haddon848efb42021-09-09 12:30:53 +01001941 self.ser.addOperator(op['op'], [input.name], [result_tens.name], attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001942 return result_tens
1943
1944 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001945 assert len(padding) == 4
1946 result_tens = OutputShaper.conv2dOp(
1947 self.ser, ifm, filter, strides, padding, dilations
1948 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001949
1950 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001951 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001952
Kevin Cheng550ccc52021-03-03 11:21:43 -08001953 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001954 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001955 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001956 return result_tens
1957
Kevin Cheng1533b852021-09-01 12:51:58 -07001958 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
1959 assert len(padding) == 6
1960 result_tens = OutputShaper.conv3dOp(
1961 self.ser, ifm, filter, strides, padding, dilations
1962 )
1963
1964 attr = ts.TosaSerializerAttribute()
1965 attr.ConvAttribute(padding, strides, dilations)
1966
1967 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001968 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07001969 )
1970 return result_tens
1971
Kevin Cheng550ccc52021-03-03 11:21:43 -08001972 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001973 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001974 ):
1975 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001976 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1977
1978 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001979 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07001980
Kevin Cheng550ccc52021-03-03 11:21:43 -08001981 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001982 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001983 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001984 return result_tens
1985
Kevin Cheng550ccc52021-03-03 11:21:43 -08001986 def build_depthwise_conv2d(
1987 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1988 ):
1989 result_tens = OutputShaper.depthwiseConv2dOp(
1990 self.ser, ifm, filter, strides, padding, dilations
1991 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001992
1993 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001994 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001995
Kevin Cheng550ccc52021-03-03 11:21:43 -08001996 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001997 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001998 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001999 return result_tens
2000
2001 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
2002 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
2003
Kevin Cheng550ccc52021-03-03 11:21:43 -08002004 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002005 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002006 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002007 return result_tens
2008
2009 def build_matmul(self, op, a, b, qinfo):
2010 result_tens = OutputShaper.matmulOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002011 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002012 return result_tens
2013
2014 def build_reduce(self, op, a, axis):
2015 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
2016
2017 attr = ts.TosaSerializerAttribute()
2018 attr.AxisAttribute(axis)
2019
Matthew Haddon848efb42021-09-09 12:30:53 +01002020 self.ser.addOperator(op['op'], [a.name], result_tens.name, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002021 return result_tens
2022
2023 def build_clamp(self, op, a):
2024 result_tens = OutputShaper.unaryOp(self.ser, a)
2025
2026 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01002027 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07002028
2029 if a.dtype == DType.FLOAT:
2030 attr.ClampAttribute(0, 0, min(v), max(v))
2031 else:
2032 attr.ClampAttribute(min(v), max(v), 0, 0)
2033
Matthew Haddon848efb42021-09-09 12:30:53 +01002034 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002035 return result_tens
2036
2037 def build_leaky_relu(self, op, a):
2038 result_tens = OutputShaper.unaryOp(self.ser, a)
2039 attr = ts.TosaSerializerAttribute()
2040
2041 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
2042
Matthew Haddon848efb42021-09-09 12:30:53 +01002043 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002044 return result_tens
2045
2046 # Needs an additional type/input
2047 def build_prelu(self, op, a):
2048 result_tens = OutputShaper.unaryOp(self.ser, a)
2049
Matthew Haddon848efb42021-09-09 12:30:53 +01002050 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002051 return result_tens
2052
Eric Kunzee5e26762020-10-13 16:11:07 -07002053 def build_sigmoid(self, op, a):
2054 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01002055 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002056 return result_tens
2057
2058 def build_tanh(self, op, a):
2059 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01002060 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002061 return result_tens
2062
Matthew Haddon818ab902021-07-27 09:12:49 +01002063 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07002064 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01002065
2066 # To store variable length list of input tensors we need to store axis along with it
2067 axis = a[-1]
2068 a = a[:-1]
2069
2070 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002071
2072 attr = ts.TosaSerializerAttribute()
2073 attr.AxisAttribute(axis)
2074
Matthew Haddon818ab902021-07-27 09:12:49 +01002075 input_tensor_names = []
2076 for tensor in a:
2077 input_tensor_names.append(tensor.name)
2078
Matthew Haddon848efb42021-09-09 12:30:53 +01002079 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
2080 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002081
2082 def build_pad(self, op, a, padding, qinfo):
2083 result_tens = OutputShaper.padOp(self.ser, a, padding)
2084
2085 # Need to turn the padding array into a TOSA tensor here.
2086 # This is one of the few tensor operands that does not get
2087 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08002088 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07002089
Kevin Cheng550ccc52021-03-03 11:21:43 -08002090 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002091 op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002092 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002093 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002094
2095 def build_reshape(self, op, a, newShape):
2096 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
2097
2098 attr = ts.TosaSerializerAttribute()
2099 attr.ReshapeAttribute(newShape)
2100
Matthew Haddon848efb42021-09-09 12:30:53 +01002101 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002102 return result_tens
2103
2104 def build_reverse(self, op, a, axis):
2105 result_tens = OutputShaper.unaryOp(self.ser, a)
2106
2107 attr = ts.TosaSerializerAttribute()
2108 attr.AxisAttribute(axis)
2109
Matthew Haddon848efb42021-09-09 12:30:53 +01002110 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002111 return result_tens
2112
2113 def build_transpose(self, op, a, perms):
2114 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
2115
Kevin Cheng550ccc52021-03-03 11:21:43 -08002116 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07002117
Matthew Haddon848efb42021-09-09 12:30:53 +01002118 self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002119 return result_tens
2120
2121 def build_slice(self, op, a, begin, size):
2122 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
2123
2124 attr = ts.TosaSerializerAttribute()
2125 attr.SliceAttribute(begin, size)
2126
Matthew Haddon848efb42021-09-09 12:30:53 +01002127 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002128 return result_tens
2129
2130 def build_tile(self, op, a, multiples):
2131 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
2132
2133 attr = ts.TosaSerializerAttribute()
2134 attr.TileAttribute(multiples)
2135
Matthew Haddon848efb42021-09-09 12:30:53 +01002136 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002137 return result_tens
2138
Kevin Cheng77d0f762020-11-24 10:26:32 -08002139 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07002140
2141 # Create a new indicies tensor
2142 # here with data that doesn't exceed the dimensions of the values tensor
2143
Kevin Cheng550ccc52021-03-03 11:21:43 -08002144 K = values.shape[1] # K
2145 W = self.randInt(
2146 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
2147 ) # W
2148 indicies_arr = np.int32(
2149 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
2150 ) # (N, W)
2151 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002152
Kevin Cheng77d0f762020-11-24 10:26:32 -08002153 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07002154
Matthew Haddon848efb42021-09-09 12:30:53 +01002155 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002156
2157 return result_tens
2158
Kevin Cheng77d0f762020-11-24 10:26:32 -08002159 def build_scatter(self, op, values_in, input):
2160
2161 # Create a new indicies tensor
2162 # here with data that doesn't exceed the dimensions of the values_in tensor
2163
Kevin Cheng550ccc52021-03-03 11:21:43 -08002164 K = values_in.shape[1] # K
2165 W = input.shape[1] # W
2166 indicies_arr = np.int32(
2167 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
2168 ) # (N, W)
2169 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002170
2171 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
2172
Kevin Cheng550ccc52021-03-03 11:21:43 -08002173 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002174 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002175 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002176
2177 return result_tens
2178
Matthew Haddon848efb42021-09-09 12:30:53 +01002179
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        validator_fcns,
        error_name = None,
    ):
        """Build RESIZE, optionally mutated into an ERROR_IF negative test.

        When error_name is set, the input/output lists may be invalidated
        and the validator functions record the expected failure.
        """
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
            error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Run the ERROR_IF validators over the full argument set so the
        # expected error (if any) is captured alongside the test.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            mode=mode,
            shift=shift,
            input_dtype=input_dtype,
            output_dtype=output_dtype,
            input_shape=input,
            output_shape=output_dims,
            offset=offset,
            offset_fp=offset_fp,
            stride=stride,
            stride_fp=stride_fp,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
2246
    def build_identityn(self, op, val, val2):
        """Build IDENTITYN over two tensors; returns the first output only."""

        result_tens = OutputShaper.unaryOp(self.ser, val)
        result_tens2 = OutputShaper.unaryOp(self.ser, val2)
        # NOTE(review): passes 'op' directly where sibling builders pass
        # op['op'] — confirm whether this builder still receives a raw
        # opcode rather than the op-dict.
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens
2255
Kevin Cheng17e92022021-10-01 14:33:33 -07002256 def build_const(self, op, val):
2257 self.ser.addOutputTensor(val)
2258 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002259
2260 # Type Conversion
2261 def build_cast(self, op, val, out_dtype):
2262 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002263 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002264 return result_tens
2265
2266 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
2267 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2268
2269 if per_channel:
2270 nc = val.shape[-1]
2271 else:
2272 nc = 1
2273
2274 in_type_width = self.typeWidth(val.dtype)
2275 out_type_width = self.typeWidth(out_dtype)
2276
Kevin Cheng3a478572021-01-22 17:21:02 -08002277 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002278 input_zp = self.randInt(-128, 128)
2279 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002280 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002281 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002282 in_type_width = in_type_width + 1
2283 else:
2284 input_zp = 0
2285
Kevin Cheng3a478572021-01-22 17:21:02 -08002286 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002287 output_zp = self.randInt(-128, 128)
2288 out_type_width = out_type_width + 1
2289 elif out_dtype == DType.UINT8:
2290 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002291 out_type_width = out_type_width + 1
2292 else:
2293 output_zp = 0
2294
2295 # Calculate scale based on:
2296 # scale = a *(2^output_width)/(2^input_width))
2297
2298 a = np.float32(self.rng.random(size=[nc]))
2299 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2300
2301 if scale32:
2302 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002303 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002304 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2305 else:
2306 # Cap the scaling at 2^15 - 1 for scale16
2307 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2308
Kevin Cheng550ccc52021-03-03 11:21:43 -08002309 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002310
2311 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2312 shift_arr = np.int32(np.zeros(shape=[nc]))
2313
2314 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002315 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2316 scale_arr[i], scale32
2317 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002318
Kevin Cheng550ccc52021-03-03 11:21:43 -08002319 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002320
2321 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002322 attr.RescaleAttribute(
2323 input_zp,
2324 output_zp,
2325 multiplier_arr,
2326 shift_arr,
2327 scale32,
2328 double_round,
2329 per_channel,
2330 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002331
Matthew Haddon848efb42021-09-09 12:30:53 +01002332 self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002333 return result_tens
2334
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Build a COND_IF test where both branches produce constant tensors.

        `then_tens`/`else_tens` are only consulted for their shape; fresh
        random INT32 constants of that shape are emitted inside the
        THEN/ELSE basic blocks, and `cond` selects the branch.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor (rank-0 boolean constant)
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
2370
2371 def build_cond_if_binary(self, op, a, b, cond):
2372 # For cond_if with a binary op in the then/else blocks, take a and b and
2373 # alternately add or subtract them based on the condition
2374
2375 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002376 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002377
Kevin Cheng550ccc52021-03-03 11:21:43 -08002378 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002379
2380 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002381 then_block = "THEN_BLOCK"
2382 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002383 attr = ts.TosaSerializerAttribute()
2384 attr.CondIfAttribute(then_block, else_block)
2385
2386 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002387 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002388 op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08002389 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002390
2391 self.ser.startBasicBlock(then_block)
2392 self.ser.addInputTensor(a)
2393 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002394 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002395 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
2396
2397 self.ser.startBasicBlock(else_block)
2398 self.ser.addInputTensor(a)
2399 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002400 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002401 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
2402
2403 return result_tens
2404
    def build_while_loop(self, op, a, iter_val):
        """Build a WHILE_LOOP test that adds `a` into an accumulator
        `iter_val` times.

        COND block: loop while iter > 0.  BODY block: acc += a, iter -= 1.
        Returns the accumulator output tensor.
        """
        # NOTE(review): local name `iter` shadows the builtin of the same name.
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor, zero-initialized to match a's shape
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # COND block (input: iter, output: cond_tens )
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
2458
Matthew Haddon1c00b712021-10-01 15:51:03 +01002459 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
2460 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2461 default_test_rank_range = range(1, 5)
2462 if not shapeFilter:
2463 shapeFilter = [None]
2464
2465 # Calculate the filters based on what is requested and what the operator allows
2466 rmin, rmax = op["rank"]
2467 if rankFilter is not None:
2468 cleanRankFilter = []
2469 # Ensure rankFilter values are allowed by operator
2470 for rank in rankFilter:
2471 if rank >= rmin and rank <= rmax:
2472 cleanRankFilter.append(rank)
2473 elif rankFilter is None and shapeFilter[0] is None:
2474 cleanRankFilter = []
2475 # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
2476 rankRange = range(rmin, rmax + 1)
2477 for rank in rankRange:
2478 if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
2479 cleanRankFilter.append(rank)
2480 else:
2481 cleanRankFilter = range(rmin, rmax + 1)
2482
2483 dtypes = op["types"]
2484 if dtypeFilter is not None:
2485 cleanDtypeFilter = []
2486 # Ensure filtered dtypes are allowed by operator
2487 for dtype in dtypeFilter:
2488 if dtype in dtypes:
2489 cleanDtypeFilter.append(dtype)
2490 else:
2491 cleanDtypeFilter = dtypes
2492
2493 if testType == 'positive':
2494 filterDict = {
2495 'shapeFilter': shapeFilter,
2496 'rankFilter': cleanRankFilter,
2497 'dtypeFilter': cleanDtypeFilter
2498 }
2499 return filterDict
2500 elif testType == 'negative':
2501 validator_info = validator(check=False, op=op)
2502 error_arguments = validator_info['param_reqs']
2503
2504 #Set parameters as required
2505 if error_arguments['rank'] != None:
2506 rankFilter = error_arguments['rank']
2507 else:
2508 rankFilter = cleanRankFilter
2509
2510 if error_arguments['dtype'] != None:
2511 dtypeFilter = error_arguments['dtype']
2512 else:
2513 dtypeFilter = cleanDtypeFilter
2514
2515 if error_arguments['shape'] != None:
2516 shapeFilter = error_arguments['shape']
2517 else:
2518 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
2519
2520 filterDict = {
2521 'shapeFilter': shapeFilter,
2522 'rankFilter': rankFilter,
2523 'dtypeFilter': dtypeFilter
2524 }
2525 return filterDict
2526
2527
Kevin Cheng550ccc52021-03-03 11:21:43 -08002528 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01002529 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08002530 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002531
2532 try:
2533 op = self.TOSA_OP_LIST[opName]
2534 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002535 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002536
2537 # Initialize a new random number generator
2538 self.rng = np.random.default_rng(self.random_seed)
2539
Kevin Cheng550ccc52021-03-03 11:21:43 -08002540 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002541
Eric Kunzee5e26762020-10-13 16:11:07 -07002542 # Test list consists of a tuple of:
2543 # (opName, testNameStr, dtype, shapeList, argumentsList)
2544 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01002545 if testType == 'negative' and "error_if_validators" in op:
2546 error_if_validators = op["error_if_validators"]
2547 else:
2548 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002549
Matthew Haddon1c00b712021-10-01 15:51:03 +01002550 for validator in error_if_validators:
2551 if validator is not None:
2552 error_name = validator(check=False, op=op)['error_name']
2553 #print("error_name: ", error_name)
2554 else:
2555 error_name = None
2556
2557 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
2558 cleanRankFilter = filterDict['rankFilter']
2559 cleanDtypeFilter = filterDict['dtypeFilter']
2560 cleanShapeFilter = filterDict['shapeFilter']
2561 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
2562
2563 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07002564 if opName.startswith("conv3d"):
2565 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01002566 for t in cleanDtypeFilter:
2567 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002568 # Filter out by rank
2569 if shape is not None and len(shape) != r:
2570 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002571 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002572 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002573
Matthew Haddon74567092021-07-16 15:38:20 +01002574 shapeStr = self.shapeStr(shapeList[0])
2575 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002576
Matthew Haddon74567092021-07-16 15:38:20 +01002577 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2578 argList = []
2579 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002580 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002581 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002582 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002583
Matthew Haddon74567092021-07-16 15:38:20 +01002584 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002585 if testType == 'positive':
2586 if argStr:
2587 testStr = "{}_{}_{}_{}".format(
2588 opName, shapeStr, typeStr, argStr
2589 )
2590 else:
2591 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
2592 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01002593 if argStr:
2594 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2595 opName, error_name, shapeStr, typeStr, argStr
2596 )
2597 else:
2598 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002599
2600 testList.append((opName, testStr, t, error_name, shapeList, args))
2601
2602 if testType == 'positive':
2603 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2604 if "invalid_test_validators" in op:
2605 invalid_test_validators = op["invalid_test_validators"]
2606 clean_testList = []
2607 for test in testList:
2608 for validator_fcn in invalid_test_validators:
2609 remove_test = False
2610 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
2611 remove_test = True
2612 if not remove_test:
2613 clean_testList.append(test)
2614 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002615
2616 return testList
2617
Matthew Haddone86fd342021-09-07 16:12:21 +01002618
    def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
        """Build and serialize a single test case generated by genOpTestList.

        Args:
            opName: key into TOSA_OP_LIST.
            testStr: name used for the serialized test output.
            dtype_or_dtypeList: one dtype for all operands, or a per-operand list.
            error_name: ERROR_IF name for negative tests, or None.
            shapeList: operand shapes (from the op's tensor generator).
            testArgs: extra arguments forwarded to the op's build function.

        Raises:
            Exception: if opName is not in TOSA_OP_LIST.
            TypeError: re-raised from the build function after logging its inputs.
        """
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        # Negative tests forward the op's validators into the build function.
        if "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
        else:
            error_if_validators = None

        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        # Normalize the dtype argument into one dtype per operand.
        # CONCAT is special: it has one dtype per input shape.
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        elif op["op"] == Op.CONCAT:
            dtypeList = [dtype_or_dtypeList] * len(shapeList)
        else:
            dtypeList = [dtype_or_dtypeList] * (num_operands)

        if op["op"] != Op.CONCAT:
            assert (
                len(shapeList) == num_operands
            ), "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
            assert (
                len(dtypeList) == num_operands
            ), "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )

        # Optional quantization-info generator for this op.
        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = []

        tens = self.generate_tensors(op, dtypeList, shapeList, testArgs)

        if qgen is not None:
            qinfo = qgen(self, op, dtype_or_dtypeList)
        else:
            qinfo = None

        # Dispatch to the op's build function; the argument list shape
        # depends on whether qinfo and/or ERROR_IF validators are present.
        try:
            if error_if_validators is None:
                if qinfo is not None:
                    resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
                else:
                    resultName = build_fcn(self, op, *tens, *testArgs)
            else:
                if qinfo is not None:
                    resultName = build_fcn(self, op, *tens, *testArgs, qinfo, error_if_validators, error_name)
                else:
                    resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
        except TypeError as e:
            # Log the mismatched call before re-raising for easier debugging.
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        if resultName is None:
            print("Invalid ERROR_IF tests created")

        # Save the serialized test
        self.serialize("test")
2695
2696
    def generate_tensors(self, op, dtypeList, shapeList, testArgs):
        """Create the input tensors (placeholders and consts) for one test.

        Most ops take purely random tensors, but several ops need their
        inputs constrained so the reference computation stays in range
        (ADD/SUB INT32 saturation, shift amounts, INTDIV corner cases,
        MUL INT32 overflow) or need special operand handling (SELECT's
        boolean condition, CONCAT's const inputs).

        Args:
            op: TOSA_OP_LIST entry.
            dtypeList: dtype per operand.
            shapeList: shape per operand.
            testArgs: op-specific arguments (e.g. the shift for MUL).

        Returns:
            List of created tensors, placeholders first then consts.
        """
        pCount, cCount = op["operands"]

        tens = []
        if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"

            placeholders = []
            add = (op["op"] == Op.ADD)
            a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the result in int64 so overflow can be detected.
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31)-1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            # Force value of operand[1] to be within [0, num_bits]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Shift-amount operand: bound by the input type's width.
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.INTDIV:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Regenerate the random tensors until neither case occurs.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    # Halve both operands until the (rounded, shifted)
                    # product fits in int32.
                    i = 0
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2 ** 31)).all() and (
                            result_arr <= ((2 ** 31) - 1)
                        ).all():
                            break

                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            # Split inputs between placeholders and consts per the
            # --num-const-inputs-concat command-line option.
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            # Default: first pCount operands are placeholders, rest consts.
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002898
2899 def createDynamicOpLists(self):
2900
2901 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07002902 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002903
Kevin Cheng1533b852021-09-01 12:51:58 -07002904 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002905 testName = "conv2d_{}x{}".format(k[0], k[1])
2906 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2907 self.TOSA_OP_LIST[testName]["filter"] = k
2908 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002909
Kevin Cheng550ccc52021-03-03 11:21:43 -08002910 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2911 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2912 "depthwise_conv2d_TEMPLATE"
2913 ].copy()
2914 self.TOSA_OP_LIST[testName]["filter"] = k
2915 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002916
Kevin Cheng550ccc52021-03-03 11:21:43 -08002917 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2918 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2919 "transpose_conv2d_TEMPLATE"
2920 ].copy()
2921 self.TOSA_OP_LIST[testName]["filter"] = k
2922 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002923
Kevin Cheng1533b852021-09-01 12:51:58 -07002924 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2925 for k in KERNELS_3D:
2926 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2927 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2928 self.TOSA_OP_LIST[testName]["filter"] = k
2929 self.TOSA_OP_LIST[testName]["template"] = False
2930
Eric Kunzee5e26762020-10-13 16:11:07 -07002931 # Delete any templates after having created any dynamic ops
2932 # This is a two-pass operation because it's bad practice to delete
2933 # keys from dictionaries while iterating
2934 keyList = []
2935 for k in self.TOSA_OP_LIST:
2936 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002937 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07002938 keyList.append(k)
2939 continue
2940 except KeyError:
2941 pass
2942
2943 for k in keyList:
2944 del self.TOSA_OP_LIST[k]
2945
2946 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002947 """Fill in default fields for ops if they aren't already specified.
2948 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002949 for op in self.TOSA_OP_LIST:
2950
2951 # Required fields
2952 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002953 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002954 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002955 raise Exception(
2956 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2957 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002958
2959 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002960 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002961 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002962 raise Exception(
2963 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2964 op
2965 )
2966 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002967
2968 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002969 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002970 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002971 raise Exception(
2972 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2973 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002974
2975 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002976 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002977 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002978 raise Exception(
2979 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2980 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002981
2982 # Put in default rank range, if missing
2983 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002984 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002985 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002986 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002987
    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    # if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested

    # Shorthand dtype groups referenced by the 'types' fields below.
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # NOTE(review): each triple appears to be a dtype combination for one
    # operand set (e.g. input/weight/accumulator) - confirm against usage.
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Rank range applied by initOpListDefaults when an op omits 'rank'.
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003015
3016 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003017 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003018 "argmax": {
3019 "op": Op.ARGMAX,
3020 "operands": (1, 0),
3021 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3022 "types": TYPE_NARROW_INT_FP,
3023 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003024 "avg_pool2d": {
3025 "op": Op.AVG_POOL2D,
3026 "operands": (1, 0),
3027 "rank": (4, 4),
3028 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3029 "qgen": TosaQuantGen.qgUnary,
3030 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003031 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003032 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003033 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003034 "conv2d_TEMPLATE": {
3035 "op": Op.CONV2D,
3036 "operands": (1, 2),
3037 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01003038 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003039 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003040 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003041 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003042 "template": True,
3043 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003044 # Templated operator. Filled in by createDynamicOpLists
3045 "conv3d_TEMPLATE": {
3046 "op": Op.CONV3D,
3047 "operands": (1, 2),
3048 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01003049 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07003050 "qgen": TosaQuantGen.qgConv,
3051 "types": TYPE_CONV,
3052 "template": True,
3053 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003054 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003055 "depthwise_conv2d_TEMPLATE": {
3056 "op": Op.DEPTHWISE_CONV2D,
3057 "operands": (1, 2),
3058 "filter": [1, 1],
3059 "rank": (4, 4),
3060 "build_fcn": (
3061 build_depthwise_conv2d,
3062 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01003063 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003064 ),
3065 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003066 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003067 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003068 "template": True,
3069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 "fully_connected": {
3071 "op": Op.FULLY_CONNECTED,
3072 "operands": (1, 2),
3073 "rank": (2, 2),
3074 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
3075 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003076 "types": TYPE_CONV,
Jared Smolens573ecd42021-03-04 15:24:10 -08003077 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003078 "matmul": {
3079 "op": Op.MATMUL,
3080 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003081 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08003082 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
3083 "qgen": TosaQuantGen.qgMatmul,
3084 "types": TYPE_NARROW_INT_FP,
3085 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003086 "max_pool2d": {
3087 "op": Op.MAX_POOL2D,
3088 "operands": (1, 0),
3089 "rank": (4, 4),
3090 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3091 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003092 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003093 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003094 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003095 "transpose_conv2d_TEMPLATE": {
3096 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003097 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003098 "rank": (4, 4),
3099 "build_fcn": (
3100 build_transpose_conv2d,
3101 TosaTensorGen.tgTransposeConv2D,
3102 TosaArgGen.agTransposeConv2D,
3103 ),
3104 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003105 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003106 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003107 "template": True,
3108 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003109 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003110 "clamp": {
3111 "op": Op.CLAMP,
3112 "operands": (1, 0),
3113 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
3114 "types": TYPE_NARROW_INT_FP,
3115 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003116 "sigmoid": {
3117 "op": Op.SIGMOID,
3118 "operands": (1, 0),
3119 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
3120 "types": TYPE_FP,
3121 },
3122 "tanh": {
3123 "op": Op.TANH,
3124 "operands": (1, 0),
3125 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
3126 "types": TYPE_FP,
3127 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003128 # Elementwise Binary Operators
3129 "add": {
3130 "op": Op.ADD,
3131 "operands": (2, 0),
3132 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3133 "types": TYPE_FI32,
3134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003135 "arithmetic_right_shift": {
3136 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3137 "operands": (2, 0),
3138 "build_fcn": (
3139 build_arithmetic_right_shift,
3140 TosaTensorGen.tgBroadcastFuzz,
3141 TosaArgGen.agArithmeticRightShift,
3142 ),
3143 "types": TYPE_INT,
3144 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003145 "bitwise_and": {
3146 "op": Op.BITWISE_AND,
3147 "operands": (2, 0),
3148 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3149 "types": TYPE_INT,
3150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "bitwise_or": {
3152 "op": Op.BITWISE_OR,
3153 "operands": (2, 0),
3154 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3155 "types": TYPE_INT,
3156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 "bitwise_xor": {
3158 "op": Op.BITWISE_XOR,
3159 "operands": (2, 0),
3160 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3161 "types": TYPE_INT,
3162 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003163 "intdiv": {
3164 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003165 "operands": (2, 0),
3166 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3167 "types": [DType.INT32],
3168 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003169 "logical_and": {
3170 "op": Op.LOGICAL_AND,
3171 "operands": (2, 0),
3172 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3173 "types": TYPE_BOOL,
3174 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 "logical_left_shift": {
3176 "op": Op.LOGICAL_LEFT_SHIFT,
3177 "operands": (2, 0),
3178 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3179 "types": TYPE_INT,
3180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003181 "logical_right_shift": {
3182 "op": Op.LOGICAL_RIGHT_SHIFT,
3183 "operands": (2, 0),
3184 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3185 "types": TYPE_INT,
3186 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003187 "logical_or": {
3188 "op": Op.LOGICAL_OR,
3189 "operands": (2, 0),
3190 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3191 "types": TYPE_BOOL,
3192 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 "logical_xor": {
3194 "op": Op.LOGICAL_XOR,
3195 "operands": (2, 0),
3196 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3197 "types": TYPE_BOOL,
3198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 "maximum": {
3200 "op": Op.MAXIMUM,
3201 "operands": (2, 0),
3202 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3203 "types": TYPE_FI32,
3204 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003205 "minimum": {
3206 "op": Op.MINIMUM,
3207 "operands": (2, 0),
3208 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3209 "types": TYPE_FI32,
3210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 "mul": {
3212 "op": Op.MUL,
3213 "operands": (2, 0),
3214 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
3215 "types": TYPE_INT_FP,
3216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003217 "pow": {
3218 "op": Op.POW,
3219 "operands": (2, 0),
3220 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
3221 "types": TYPE_FP,
3222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 "sub": {
3224 "op": Op.SUB,
3225 "operands": (2, 0),
3226 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3227 "types": TYPE_FI32,
3228 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003229 "table": {
3230 "op": Op.TABLE,
3231 # Use the automatic generation functions to create the input array
3232 # but create the table tensor in the build function, as it may be
3233 # a different type from the input
3234 "operands": (1, 0),
3235 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003236 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08003237 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003238 # Elementwise Unary operators
3239 "abs": {
3240 "op": Op.ABS,
3241 "operands": (1, 0),
3242 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3243 "types": TYPE_FI32,
3244 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003245 "bitwise_not": {
3246 "op": Op.BITWISE_NOT,
3247 "operands": (1, 0),
3248 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3249 "types": TYPE_INT,
3250 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003251 "ceil": {
3252 "op": Op.CEIL,
3253 "operands": (1, 0),
3254 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3255 "types": TYPE_FP,
3256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003257 "clz": {
3258 "op": Op.CLZ,
3259 "operands": (1, 0),
3260 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3261 "types": [DType.INT32],
3262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "exp": {
3264 "op": Op.EXP,
3265 "operands": (1, 0),
3266 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3267 "types": TYPE_FP,
3268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003269 "floor": {
3270 "op": Op.FLOOR,
3271 "operands": (1, 0),
3272 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3273 "types": TYPE_FP,
3274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003275 "log": {
3276 "op": Op.LOG,
3277 "operands": (1, 0),
3278 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3279 "types": TYPE_FP,
3280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 "logical_not": {
3282 "op": Op.LOGICAL_NOT,
3283 "operands": (1, 0),
3284 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3285 "types": TYPE_BOOL,
3286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 "negate": {
3288 "op": Op.NEGATE,
3289 "operands": (1, 0),
3290 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3291 "qgen": TosaQuantGen.qgUnary,
3292 "types": TYPE_INT_FP,
3293 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003294 "reciprocal": {
3295 "op": Op.RECIPROCAL,
3296 "operands": (1, 0),
3297 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3298 "types": TYPE_FP,
3299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 "rsqrt": {
3301 "op": Op.RSQRT,
3302 "operands": (1, 0),
3303 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3304 "types": TYPE_FP,
3305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003306 # Elementwise Ternary operators
3307 "select": {
3308 "op": Op.SELECT,
3309 "operands": (3, 0),
3310 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
3311 "types": TYPE_FIB,
3312 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 # Comparison operators
3314 "equal": {
3315 "op": Op.EQUAL,
3316 "operands": (2, 0),
3317 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3318 "types": TYPE_FI32,
3319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 "greater_equal": {
3321 "op": Op.GREATER_EQUAL,
3322 "operands": (2, 0),
3323 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3324 "types": TYPE_FI32,
3325 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 "greater": {
3327 "op": Op.GREATER,
3328 "operands": (2, 0),
3329 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3330 "types": TYPE_FI32,
3331 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 # Reduction operators
3333 "reduce_all": {
3334 "op": Op.REDUCE_ALL,
3335 "operands": (1, 0),
3336 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3337 "types": TYPE_BOOL,
3338 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003339 "reduce_any": {
3340 "op": Op.REDUCE_ANY,
3341 "operands": (1, 0),
3342 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3343 "types": TYPE_BOOL,
3344 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003345 "reduce_max": {
3346 "op": Op.REDUCE_MAX,
3347 "operands": (1, 0),
3348 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3349 "types": TYPE_INT_FP,
3350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003351 "reduce_min": {
3352 "op": Op.REDUCE_MAX,
3353 "operands": (1, 0),
3354 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3355 "types": TYPE_INT_FP,
3356 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003357 "reduce_product": {
3358 "op": Op.REDUCE_PRODUCT,
3359 "operands": (1, 0),
3360 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3361 "types": TYPE_FP,
3362 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 "reduce_sum": {
3364 "op": Op.REDUCE_SUM,
3365 "operands": (1, 0),
3366 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3367 "types": TYPE_FI32,
3368 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003369 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003370 "concat": {
3371 "op": Op.CONCAT,
3372 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01003373 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003374 "types": TYPE_FIB,
3375 },
3376 "pad": {
3377 "op": Op.PAD,
3378 "operands": (1, 0),
3379 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
3380 "qgen": TosaQuantGen.qgPad,
3381 "types": TYPE_FIB,
3382 },
3383 "reshape": {
3384 "op": Op.RESHAPE,
3385 "operands": (1, 0),
3386 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
3387 "types": TYPE_FIB,
3388 },
3389 "reverse": {
3390 "op": Op.REVERSE,
3391 "operands": (1, 0),
3392 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3393 "types": TYPE_FIB,
3394 },
3395 "slice": {
3396 "op": Op.SLICE,
3397 "operands": (1, 0),
3398 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
3399 "types": TYPE_FIB,
3400 },
3401 "tile": {
3402 "op": Op.TILE,
3403 "operands": (1, 0),
3404 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
3405 "types": TYPE_FIB,
3406 },
3407 "transpose": {
3408 "op": Op.TRANSPOSE,
3409 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003410 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003411 "build_fcn": (
3412 build_transpose,
3413 TosaTensorGen.tgBasic,
3414 TosaArgGen.agTranspose,
3415 ),
3416 "types": TYPE_FIB,
3417 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 # Data nodes
3419 "const": {
3420 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003421 "operands": (0, 1),
3422 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 "types": TYPE_FIB,
3424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 "identity": {
3426 "op": Op.IDENTITY,
3427 "operands": (1, 0),
3428 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3429 "types": TYPE_FIB,
3430 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003431 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003432 "gather": {
3433 "op": Op.GATHER,
3434 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3435 "operands": (1, 0),
3436 "rank": (3, 3),
3437 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
3438 "types": TYPE_INT_FP,
3439 },
3440 "scatter": {
3441 "op": Op.SCATTER,
3442 # Only specify 'values_in' tensor here.
3443 #'indices' and 'input' are generated in op building stage
3444 "operands": (2, 0),
3445 "rank": (3, 3),
3446 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
3447 "types": TYPE_INT_FP,
3448 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003449 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003450 "resize": {
3451 "op": Op.RESIZE,
3452 "operands": (1, 0),
3453 "rank": (4, 4),
3454 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
3455 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01003456 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
3457 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
3458 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01003459 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
3460 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003461 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003462 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003463 "cast": {
3464 "op": Op.CAST,
3465 "operands": (1, 0),
3466 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
3467 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
3468 },
3469 "rescale": {
3470 "op": Op.RESCALE,
3471 "operands": (1, 0),
3472 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003473 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08003474 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003475 # Custom
3476 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003477 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003478 # Two varients of cond_if, one that generates one of two constant tensors (no
3479 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3480 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003481 "cond_if_const": {
3482 "op": Op.COND_IF,
3483 "operands": (0, 2),
3484 "build_fcn": (
3485 build_cond_if_const,
3486 TosaTensorGen.tgBasic,
3487 TosaArgGen.agCondIf,
3488 ),
3489 "types": [DType.BOOL],
3490 },
3491 "cond_if_binary": {
3492 "op": Op.COND_IF,
3493 "operands": (2, 0),
3494 "build_fcn": (
3495 build_cond_if_binary,
3496 TosaTensorGen.tgBasic,
3497 TosaArgGen.agCondIf,
3498 ),
3499 "types": TYPE_FI32,
3500 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003501 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003502 "while_loop": {
3503 "op": Op.WHILE_LOOP,
3504 "operands": (0, 1),
3505 "build_fcn": (
3506 build_while_loop,
3507 TosaTensorGen.tgBasic,
3508 TosaArgGen.agWhileLoop,
3509 ),
3510 "types": [DType.INT32],
3511 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003512 }
3513
Kevin Cheng550ccc52021-03-03 11:21:43 -08003514
Eric Kunzee5e26762020-10-13 16:11:07 -07003515class OutputShaper:
3516 # Methods in this class compute the expected output shape and datatype
3517 # for common classes of operations
    def __init__(self):
        # Stateless: every shape-inference helper below is a @staticmethod.
        pass
3520
3521 # These methods return arguments that can be used for
3522 # creating a new output tensor
3523 @staticmethod
3524 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003525 assert len(a.shape) == len(b.shape)
3526 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003527
3528 shape = []
3529 for i in range(len(a.shape)):
3530 if a.shape[i] == 1:
3531 shape.append(b.shape[i])
3532 else:
3533 shape.append(a.shape[i])
3534
Kevin Cheng550ccc52021-03-03 11:21:43 -08003535 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003536
3537 @staticmethod
3538 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003539 assert len(a.shape) == len(b.shape)
3540 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003541
3542 shape = []
3543 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003544 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003545 shape.append(a.shape[i])
3546
Kevin Cheng550ccc52021-03-03 11:21:43 -08003547 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003548
    @staticmethod
    def unaryOp(ser, a):
        # Elementwise unary op: output shape and dtype are identical to input.
        return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003552
3553 @staticmethod
3554 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003555 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
3556 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003557
3558 shape = []
3559 for i in range(len(a.shape)):
3560 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3561
Kevin Cheng550ccc52021-03-03 11:21:43 -08003562 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003563
3564 @staticmethod
3565 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003566 assert len(a.shape) == len(b.shape)
3567 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003568
3569 # Do broadcast
3570 shape = []
3571 for i in range(len(a.shape)):
3572 if a.shape[i] == 1:
3573 shape.append(b.shape[i])
3574 else:
3575 shape.append(a.shape[i])
3576
3577 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08003578 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07003579
3580 @staticmethod
3581 def reduceOp(ser, a, axis):
3582
3583 shape = a.shape.copy()
3584
3585 shape[axis] = 1
3586
Kevin Cheng550ccc52021-03-03 11:21:43 -08003587 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003588
3589 @staticmethod
3590 def argmaxOp(ser, a, axis):
3591 shape = a.shape.copy()
3592 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003593 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07003594
3595 @staticmethod
3596 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
3597
3598 # IFM: NHWC
3599 # Filter: OHWI
3600 # OFM: NHWC
3601
3602 if len(padding) == 2:
3603 # Expand padding to 4 parameters in the case of transpose_conv2d
3604 # From H,W to T,B,L,R
3605 padding = [padding[0], padding[0], padding[1], padding[1]]
3606
Kevin Cheng550ccc52021-03-03 11:21:43 -08003607 h = (
3608 ifm.shape[1]
3609 - filter.shape[1]
3610 - (filter.shape[1] - 1) * (dilations[0] - 1)
3611 + padding[0]
3612 + padding[1]
3613 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003614
Kevin Cheng550ccc52021-03-03 11:21:43 -08003615 w = (
3616 ifm.shape[2]
3617 - filter.shape[2]
3618 - (filter.shape[2] - 1) * (dilations[1] - 1)
3619 + padding[2]
3620 + padding[3]
3621 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003622
Eric Kunzee5e26762020-10-13 16:11:07 -07003623 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
3624
Kevin Cheng3a478572021-01-22 17:21:02 -08003625 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003626 out_dtype = DType.INT32
3627 elif ifm.dtype == DType.INT16:
3628 out_dtype = DType.INT48
3629 elif ifm.dtype == DType.FLOAT:
3630 out_dtype = DType.FLOAT
3631 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003632 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003633
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003635
3636 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07003637 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
3638
3639 # IFM: NDHWC
3640 # Filter: ODHWI
3641 # OFM: NDHWC
3642
3643 d = (
3644 ifm.shape[1]
3645 - filter.shape[1]
3646 - (filter.shape[1] - 1) * (dilations[0] - 1)
3647 + padding[0]
3648 + padding[1]
3649 ) // strides[0] + 1
3650
3651 h = (
3652 ifm.shape[2]
3653 - filter.shape[2]
3654 - (filter.shape[2] - 1) * (dilations[1] - 1)
3655 + padding[2]
3656 + padding[3]
3657 ) // strides[1] + 1
3658
3659 w = (
3660 ifm.shape[3]
3661 - filter.shape[3]
3662 - (filter.shape[3] - 1) * (dilations[2] - 1)
3663 + padding[4]
3664 + padding[5]
3665 ) // strides[2] + 1
3666
3667 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
3668
3669 if ifm.dtype == DType.INT8:
3670 out_dtype = DType.INT32
3671 elif ifm.dtype == DType.INT16:
3672 out_dtype = DType.INT48
3673 elif ifm.dtype == DType.FLOAT:
3674 out_dtype = DType.FLOAT
3675 else:
3676 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
3677
3678 return ser.addOutput(ofm_shape, out_dtype)
3679
3680 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07003681 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
3682 # IFM: NHWC
3683 # Filter: HWCM
3684 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08003685 h = (
3686 ifm.shape[1]
3687 - filter.shape[0]
3688 - (filter.shape[0] - 1) * (dilations[0] - 1)
3689 + padding[0]
3690 + padding[1]
3691 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003692
Kevin Cheng550ccc52021-03-03 11:21:43 -08003693 w = (
3694 ifm.shape[2]
3695 - filter.shape[1]
3696 - (filter.shape[1] - 1) * (dilations[1] - 1)
3697 + padding[2]
3698 + padding[3]
3699 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003700
Eric Kunzee5e26762020-10-13 16:11:07 -07003701 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
3702
Kevin Cheng3a478572021-01-22 17:21:02 -08003703 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003704 out_dtype = DType.INT32
3705 elif ifm.dtype == DType.INT16:
3706 out_dtype = DType.INT48
3707 elif ifm.dtype == DType.FLOAT:
3708 out_dtype = DType.FLOAT
3709 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003710 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003711
Kevin Cheng550ccc52021-03-03 11:21:43 -08003712 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003713
3714 @staticmethod
3715 def pool2dOp(ser, ifm, kernel, stride, pad):
3716 # input: NHWC
3717 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
3718 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
3719
Eric Kunzee5e26762020-10-13 16:11:07 -07003720 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003721 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003722
3723 @staticmethod
3724 def fullyConnectedOp(ser, input, filter):
3725 # input: N, IC
3726 # filter: OC, IC
3727 # output: N, OC
3728
3729 output_shape = [input.shape[0], filter.shape[0]]
3730
Kevin Cheng3a478572021-01-22 17:21:02 -08003731 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003732 out_dtype = DType.INT32
3733 elif input.dtype == DType.INT16:
3734 out_dtype = DType.INT48
3735 elif input.dtype == DType.FLOAT:
3736 out_dtype = DType.FLOAT
3737 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003738 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003739
Kevin Cheng550ccc52021-03-03 11:21:43 -08003740 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003741
3742 @staticmethod
3743 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07003744 # a: N, H, C
3745 # b: N, C, W
3746 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07003747
Kevin Cheng2d60f002021-06-09 14:18:32 -07003748 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003749
Kevin Cheng3a478572021-01-22 17:21:02 -08003750 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003751 out_dtype = DType.INT32
3752 elif a.dtype == DType.INT16:
3753 out_dtype = DType.INT48
3754 elif a.dtype == DType.FLOAT:
3755 out_dtype = DType.FLOAT
3756 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003757 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003758
Kevin Cheng550ccc52021-03-03 11:21:43 -08003759 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003760
3761 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01003762 def concatOp(ser, axis, *a):
3763 input1 = a[0]
3764 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07003765
Matthew Haddon818ab902021-07-27 09:12:49 +01003766 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07003767
Matthew Haddon818ab902021-07-27 09:12:49 +01003768 output_shape[axis] = input1.shape[axis]
3769
3770 for tensor in remaining_inputs:
3771 output_shape[axis] += tensor.shape[axis]
3772
3773 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003774
3775 @staticmethod
3776 def padOp(ser, a, padding):
3777
3778 output_shape = a.shape.copy()
3779
3780 for i in range(len(output_shape)):
3781 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
3782
Kevin Cheng550ccc52021-03-03 11:21:43 -08003783 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003784
3785 @staticmethod
3786 def reshapeOp(ser, a, shape):
3787 output_shape = shape.copy()
3788
3789 totalElements = 1
3790 for i in a.shape:
3791 totalElements *= i
3792
3793 # If there are any -1 elements, figure out what that dimension must be
3794 totalOutputElements = 1
3795 for i in output_shape:
3796 if i != -1:
3797 totalOutputElements *= i
3798
3799 # And fill it in
3800 for i in range(len(output_shape)):
3801 if output_shape[i] == -1:
3802 output_shape[i] = totalElements // totalOutputElements
3803
Kevin Cheng550ccc52021-03-03 11:21:43 -08003804 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003805
3806 @staticmethod
3807 def sliceOp(ser, a, begin, size):
3808
3809 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003810 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003811
3812 @staticmethod
3813 def tileOp(ser, a, multiples):
3814
3815 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003816 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003817
3818 for i in range(len(output_shape)):
3819 output_shape[i] = a.shape[i] * multiples[i]
3820
Kevin Cheng550ccc52021-03-03 11:21:43 -08003821 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003822
3823 @staticmethod
3824 def transposeOp(ser, a, perms):
3825 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003826 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003827
3828 for i in range(len(output_shape)):
3829 output_shape[i] = a.shape[perms[i]]
3830
Kevin Cheng550ccc52021-03-03 11:21:43 -08003831 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003832
3833 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08003834 def gatherOp(ser, values, indices):
3835 assert len(values.shape) == 3
3836 assert len(indices.shape) == 2
3837 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07003838
Kevin Cheng77d0f762020-11-24 10:26:32 -08003839 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
3840
Kevin Cheng550ccc52021-03-03 11:21:43 -08003841 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003842
3843 @staticmethod
3844 def scatterOp(ser, values_in, indices, input):
3845 assert len(values_in.shape) == 3
3846 assert len(indices.shape) == 2
3847 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08003848 assert values_in.shape[0] == indices.shape[0] # N
3849 assert input.shape[1] == indices.shape[1] # W
3850 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08003851
3852 output_shape = values_in.shape
3853
Kevin Cheng550ccc52021-03-03 11:21:43 -08003854 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003855
3856 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003857 def tableOp(ser, input, table_dtype):
3858 # Same shape as the input, but dtype dependent on table dtype
3859 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
3860 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
3861 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003862
3863 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08003864 def resizeOp(
3865 ser,
3866 input,
3867 mode,
3868 stride,
3869 offset,
3870 shift,
3871 stride_fp,
3872 offset_fp,
3873 output_dims,
3874 input_dtype,
3875 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003876 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08003877 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01003878 if error_name == ErrorIf.WrongRank:
3879 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
3880 else:
3881 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003882
Kevin Cheng550ccc52021-03-03 11:21:43 -08003883 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003884
3885 @staticmethod
3886 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003887 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003888
3889 @staticmethod
3890 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08003891 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003892 out_dtype = DType.INT32
3893 elif ifm.dtype == DType.INT16:
3894 out_dtype = DType.INT48
3895 elif ifm.dtype == DType.FLOAT:
3896 out_dtype = DType.FLOAT
3897 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003898 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003899
Kevin Cheng550ccc52021-03-03 11:21:43 -08003900 return ser.addOutput(output_shape, out_dtype)