blob: f5f7fffcd97a6a2843c142e3161be11d0baccabc [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
48DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080049Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070050ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        """Return a zero point for `dtype`.

        INT8 and UINT8 get a value drawn from testGen.randInt(); every other
        dtype gets 0.
        """
        # NOTE(review): the bounds (-128, 128) / (0, 256) suggest randInt()
        # treats the upper limit as exclusive - confirm against randInt().
        zp_range = None
        if dtype == DType.INT8:
            zp_range = (-128, 128)
        elif dtype == DType.UINT8:
            zp_range = (0, 256)

        if zp_range is None:
            return 0
        return testGen.randInt(*zp_range)

    @staticmethod
    def qgUnary(testGen, op, dtype):
        """Build a unary quant info from two independently drawn zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        first_zp = TosaQuantGen.getQinfo(testGen, dtype)
        second_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.UnaryQuantInfo(first_zp, second_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        """Build a conv quant info holding input and weights zero points.

        `dtype_or_dtypeList` is either a list of [input, weights, accumulator]
        dtypes or a single dtype shared by all three.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        dtypeList = (
            dtype_or_dtypeList
            if isinstance(dtype_or_dtypeList, list)
            else [dtype_or_dtypeList] * 3
        )
        # Only the input and weights dtypes are consulted here.
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        """Build a matmul quant info from two independently drawn zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        first_zp = TosaQuantGen.getQinfo(testGen, dtype)
        second_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.MatMulQuantInfo(first_zp, second_zp)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        """Build a pad quant info from a single zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        zero_point = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.PadQuantInfo(zero_point)
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32.  `scale32` selects
        a 31-bit multiplier scale, otherwise 15 bits are used.  The result is
        asserted to satisfy multiplier <= 2**scaleBits and 2 <= shift <= 62.
        """
        scaleBits = 31 if scale32 else 15
        full_scale = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        # Work with the magnitude of the mantissa for negative scales.
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * full_scale)
        assert multiplier <= full_scale

        # Rounding may land exactly on full scale: halve and compensate.
        shift = exponent
        if multiplier == full_scale:
            multiplier //= 2
            shift += 1

        shift = scaleBits - shift

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= full_scale
        assert shift >= 2 and shift <= 62

        return multiplier, shift
143
144
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one copy of a random rank-`rank` shape per operand.

        For ErrorIf.RankMismatch the working shape is replaced, after every
        operand except index 1, by a fresh shape of a different rank, so the
        operand appended next disagrees in rank with the first one.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # rank - 1 would be 0, so rank 1 is only ever increased
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return one copy of a random shape per operand, batch-limited.

        `rank` must be 4 unless the test targets ErrorIf.WrongRank.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for a two-placeholder rank-3 op.

        input_shape reuses dimensions 0 and 2 of values_in_shape and takes a
        randomly chosen middle dimension W.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        return [values_in_shape.copy(), input_shape.copy()]

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return one shape per operand with one random operand broadcast.

        The chosen operand gets a single randomly selected dimension set to 1.
        For ErrorIf.RankMismatch broadcasting is disabled and every operand
        except index 1 receives a fresh shape of a different rank instead.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a 2D convolution.

        The kernel height/width come from op["filter"]; the output depth is
        drawn at random.
        """
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a 3D convolution.

        The kernel depth/height/width come from op["filter"]; the output
        channel count is drawn at random.
        """
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a transpose conv.

        Shapes are produced exactly as in tgConv2D.
        """
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a depthwise conv.

        The channel multiplier M is limited to a quarter of the upper tensor
        dimension bound because the output depth is M * C.
        """
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C.
        # The divisor is clamped to 1 so that an upper dimension bound below 4
        # yields M == 1 instead of a modulo-by-zero error.
        max_filter_m = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeShape(1)[0] % max_filter_m) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input_shape, filter_shape, bias_shape] for fully connected.

        filter_shape is [OC, input_shape[1]] with OC drawn from the configured
        tensor dimension range.
        """
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] for a rank-3 matmul of two placeholders.

        b_shape is [a_shape[0], a_shape[2], b_oc].
        """
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If any dimension of a_shape is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return copies of one random shape for the operands plus extras.

        The number of extra tensors to concatenate is testGen.randInt(0, 4).
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # NOTE(review): the extra count does not depend on pl.
        num_tensors = testGen.randInt(0, 4)

        return [shape.copy() for _ in range(pl + const + num_tensors)]

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Shrink the non-first concat shapes by repeatedly halving along `axis`.

        Splits the concat shape along axis to allow for multiple const inputs
        without making too many large tensors.  `shapeList` is returned
        unchanged when it has exactly two entries or when the first shape is
        shorter along `axis` than the number of shapes.
        """
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        last_split = len(shapeList) - 3
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == last_split:
                # the final remainder becomes the last input
                new_shapeList.append(shape.copy())

        return new_shapeList
442
443
Eric Kunzee5e26762020-10-13 16:11:07 -0700444class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800445 """Argument generators create exhaustive or random lists of attributes for operators that take
446 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
447 tuples where the descriptive_name is appended to the test name and the arglist is expanded
448 as arguments to the operator build function."""
449
Eric Kunzee5e26762020-10-13 16:11:07 -0700450 def __init__(self):
451 pass
452
453 @staticmethod
def agNone(testGen, opName, shapeList, dtype, error_name=None):
    """Trivial argument generator for operators without non-tensor arguments.

    Returns a single entry with an empty name suffix and an empty arglist.
    """
    no_arguments = ("", [])
    return [no_arguments]
Eric Kunzee5e26762020-10-13 16:11:07 -0700458
459 @staticmethod
def agAxis(testGen, opName, shapeList, dtype, error_name=None):
    """Build the axis argument for operators that take a single axis.

    One ("axis<N>", [N]) entry is produced for every dimension of the first
    shape in `shapeList`.
    """
    rank = len(shapeList[0])
    return [("axis{}".format(axis), [axis]) for axis in range(rank)]
469
470 @staticmethod
def agConv(testGen, opName, shapeList, dtype, error_name=None):
    """Generate (name, [stride, padding, dilation]) arguments for convolutions.

    The kernel size is parsed from the operator name (e.g. "conv2d_3x3" =>
    [3, 3]); names starting with "conv3d" use rank-5 shapes, all others
    rank 4.  Candidate values come from the max_conv_padding / max_conv_stride
    / max_conv_dilation settings, plus some oversize values when every input
    dimension is below 64.  Combinations are sampled sparsely and kept only
    when each padded spatial dimension exceeds both the kernel size and the
    dilation.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]
    # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
    k = [int(x) for x in opName.split("_")[-1].split("x")]

    # Check the rank
    rank = 5 if opName.startswith("conv3d") else 4
    assert len(ifm_shape) == rank
    assert len(filter_shape) == rank

    # kernel rank omits batch and channels
    k_rank = rank - 2

    # Generate comprehensive argument lists
    p_vals = range(0, testGen.args.max_conv_padding + 1)
    paddings = set(itertools.product(p_vals, repeat=k_rank * 2))
    s_vals = range(1, testGen.args.max_conv_stride + 1)
    strides = set(itertools.product(s_vals, repeat=k_rank))
    d_vals = range(1, testGen.args.max_conv_dilation + 1)
    dilations = set(itertools.product(d_vals, repeat=k_rank))

    # add some oversize argument values
    if max(ifm_shape) < 64:
        bigPadding = 9
        paddings.update(itertools.product([0, bigPadding], repeat=k_rank * 2))
        bigStride = 8
        strides.update(itertools.product([1, bigStride], repeat=k_rank))
        bigDilation = 7
        dilations.update(itertools.product([1, bigDilation], repeat=k_rank))

    # There are too many parameter combinations, so generate them sparsely
    # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
    sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
    if sparsity < 13:
        sparsity = 1
    while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
        sparsity += 1

    def padded_exceeds(p, limits):
        # True when every padded spatial dimension is strictly larger than
        # the matching entry of `limits` (kernel size or dilation).
        return all(
            ifm_shape[dim + 1] + p[2 * dim] + p[2 * dim + 1] > limits[dim]
            for dim in range(k_rank)
        )

    # The candidate sets do not change while iterating, so sort them once
    # instead of re-sorting inside the nested loops.
    sorted_strides = sorted(strides)
    sorted_paddings = sorted(paddings)
    sorted_dilations = sorted(dilations)

    n = 0
    for s in sorted_strides:
        for p in sorted_paddings:
            for d in sorted_dilations:
                if (
                    n % sparsity == 0
                    # the padded shape must exceed the kernel size
                    and padded_exceeds(p, k)
                    # the padded shape must exceed the dilation
                    and padded_exceeds(p, d)
                ):
                    arg_list.append(
                        (
                            "st{}_pad{}_dilat{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "".join([str(x) for x in d]),
                            ),
                            [s, p, d],
                        )
                    )
                # n counts every combination, including rejected ones
                n += 1

    return arg_list
539
540 @staticmethod
def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
    """Generate (name, [stride, padding, dilation, output_shape]) arguments.

    Both the input and filter shapes must be rank 4.  Candidate values come
    from the max_conv_padding / max_conv_stride / max_conv_dilation settings,
    plus some oversize values when every input dimension is below 64.
    Combinations are sampled sparsely; for each kept one the output shape is
    computed and added to both the name and the arglist.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]

    # Must be rank 4
    assert len(ifm_shape) == 4
    assert len(filter_shape) == 4

    # Generate comprehensive argument lists
    p_vals = range(0, testGen.args.max_conv_padding + 1)
    paddings = set(itertools.product(p_vals, repeat=2))
    s_vals = range(1, testGen.args.max_conv_stride + 1)
    strides = set(itertools.product(s_vals, repeat=2))
    d_vals = range(1, testGen.args.max_conv_dilation + 1)
    dilations = set(itertools.product(d_vals, repeat=2))

    # add some oversize argument values
    if max(ifm_shape) < 64:
        bigPadding = 9
        paddings.update(itertools.product([0, bigPadding], repeat=2))
        bigStride = 8
        strides.update(itertools.product([1, bigStride], repeat=2))
        bigDilation = 7
        dilations.update(itertools.product([1, bigDilation], repeat=2))

    # There are too many parameter combinations, so generate them sparsely
    # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
    sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
    if sparsity < 13:
        sparsity = 1
    while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
        sparsity += 1

    # The candidate sets do not change while iterating, so sort them once
    # instead of re-sorting inside the nested loops.
    sorted_strides = sorted(strides)
    sorted_paddings = sorted(paddings)
    sorted_dilations = sorted(dilations)

    n = 0
    for s in sorted_strides:
        for p in sorted_paddings:
            for d in sorted_dilations:
                if n % sparsity == 0:
                    # Determine the output shape
                    # NOTE(review): oh/ow are not checked for being positive.
                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1
                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1
                    # named out_shape (not "os") to avoid shadowing the os module
                    out_shape = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "st{}_pad{}_dilat{}_os{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "".join([str(x) for x in d]),
                                "x".join([str(x) for x in out_shape]),
                            ),
                            [s, p, d, out_shape],
                        )
                    )
                # n counts every combination, including skipped ones
                n += 1

    return arg_list
608
609 @staticmethod
def agPad(testGen, opName, shapeList, dtype, error_name=None):
    """Exhaustively generate padding arguments for every dimension.

    Each dimension gets every (before, after) pair with values in
    [pad_min, pad_max].  The test name is "pad" followed by the before/after
    digits of each dimension; the arglist holds the padding as one numpy
    array.  For padding > 9 the name format would need to be more distinctive.
    """
    rank = len(shapeList[0])

    pad_min, pad_max = 0, 1
    pad_values = list(range(pad_min, pad_max + 1))
    # every (before, after) pair for a single dimension
    axis_pad_values = list(itertools.product(pad_values, repeat=2))

    arg_list = []
    for paddings in itertools.product(axis_pad_values, repeat=rank):
        suffix = "".join(f"{before}{after}" for before, after in paddings)
        arg_list.append(("pad" + suffix, [np.array(paddings)]))

    return arg_list
630
631 @staticmethod
def agPooling(testGen, opName, shapeList, dtype, error_name=None):
    """Generate (name, [stride, padding, kernel]) arguments for pooling ops.

    The input shape must be rank 4.  Candidate values come from the
    max_pooling_padding / max_pooling_stride / max_pooling_kernel settings
    plus some oversize values.  Combinations are sampled sparsely and kept
    only when every padding value is smaller than the kernel size of its
    dimension and the padded height/width exceed the kernel size.
    """
    arg_list = []

    shape = shapeList[0]
    assert len(shape) == 4

    # Generate comprehensive argument lists
    p_vals = range(0, testGen.args.max_pooling_padding + 1)
    paddings = set(itertools.product(p_vals, repeat=4))
    s_vals = range(1, testGen.args.max_pooling_stride + 1)
    strides = set(itertools.product(s_vals, repeat=2))
    k_vals = range(2, testGen.args.max_pooling_kernel + 2)
    kernels = set(itertools.product(k_vals, repeat=2))

    # add some oversize argument values
    bigStride = 7
    strides.update(itertools.product([1, bigStride], repeat=2))
    bigKernel = 6
    kernels.update(itertools.product([2, bigKernel], repeat=2))
    if max(shape) < 64:
        # padding must be less than the kernel size
        bigPadding = bigKernel - 1
        paddings.update(itertools.product([0, bigPadding], repeat=4))

    # There are too many parameter combinations, so generate them sparsely
    sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1

    # The candidate sets do not change while iterating, so sort them once
    # instead of re-sorting inside the nested loops.
    sorted_strides = sorted(strides)
    sorted_paddings = sorted(paddings)
    sorted_kernels = sorted(kernels)

    n = 0
    for s in sorted_strides:
        for p in sorted_paddings:
            for k in sorted_kernels:
                if (
                    n % sparsity == 0
                    # padding must be smaller than the kernel size
                    and p[0] < k[0]
                    and p[1] < k[0]
                    and p[2] < k[1]
                    and p[3] < k[1]
                    # the padded shape must exceed the kernel size
                    and (shape[1] + p[0] + p[1]) > k[0]
                    and (shape[2] + p[2] + p[3]) > k[1]
                ):
                    arg_list.append(
                        (
                            "st{}_kern{}_pad{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in k]),
                                "".join([str(x) for x in p]),
                            ),
                            [s, p, k],
                        )
                    )
                # n counts every combination, including rejected ones
                n += 1

    return arg_list
681
682 @staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate the output dtypes a cast from `inDtype` is tested with.

    Returns one ("out<DTypeName>", [dtype]) entry per output dtype and raises
    Exception for an input dtype that has no entry in the table.
    """
    integer_outputs = [DType.INT8, DType.INT16, DType.INT32]
    # (input dtype, output dtypes); scanned in order, first match wins
    cast_table = (
        (DType.INT8, [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]),
        (DType.INT16, [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]),
        (DType.INT32, [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]),
        (DType.BOOL, integer_outputs),
        (DType.FLOAT, integer_outputs),
    )

    for source_dtype, output_dtypes in cast_table:
        if inDtype == source_dtype:
            return [
                ("out{}".format(DTypeNames[out_dtype]), [out_dtype])
                for out_dtype in output_dtypes
            ]

    raise Exception("Unexpected input dtype: {}".format(inDtype))
704
705 @staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate rescale arguments: output dtype, scale32, double_round, per_channel.

    Skipped combinations:
    - UINT8 input with any output other than INT8
    - UINT8 output with any input other than INT8
    - INT48 input with scale32 set
    - double_round without scale32
    """
    arg_list = []

    # Enumerate the output types here
    for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
        # The only output dtype for UINT8 is INT8, skip all other combinations
        uint8_in_mismatch = inDtype == DType.UINT8 and dtype != DType.INT8
        # The only input dtype for UINT8 is INT8, skip all other combinations
        uint8_out_mismatch = inDtype != DType.INT8 and dtype == DType.UINT8
        if uint8_in_mismatch or uint8_out_mismatch:
            continue

        # product() iterates scale32 slowest and per_channel fastest
        for scale32, double_round, per_channel in itertools.product(
            [False, True], repeat=3
        ):
            # Illegal condition. Must be scale32=False
            if inDtype == DType.INT48 and scale32:
                continue
            # Illegal condition. ERROR_IF(!scale32 && double_round)
            if double_round and not scale32:
                continue

            test_name = "out{}_sc{}_dr{}_pc{}".format(
                DTypeNames[dtype],
                int(scale32),
                int(double_round),
                int(per_channel),
            )
            arg_list.append((test_name, [dtype, scale32, double_round, per_channel]))

    return arg_list
742
Kevin Chengaee1fac2020-11-11 13:54:06 -0800743 @staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Generate the shift argument permutations for MUL.

    For INT32 inputs, ``testGen.args.num_rand_permutations`` entries are
    produced, each with a shift drawn from ``testGen.randInt(0, 32)``.
    Every other dtype gets a single entry with a shift of zero.

    Returns a list of ``(name, [shift])`` tuples.
    """
    # Compare by value: DType members are plain integer constants, so an
    # identity test ("is") only works by accident of integer caching.
    if dtype != DType.INT32:
        return [("perm0_shift0", [0])]

    arg_list = []
    for p in range(testGen.args.num_rand_permutations):
        shift = testGen.randInt(0, 32)
        arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))

    return arg_list
757
758 @staticmethod
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
    """Return the two rounding variants for ARITHMETIC_RIGHT_SHIFT.

    None of the arguments influence the result; one entry is produced with
    rounding enabled and one with it disabled, in that order.
    """
    return [
        ("roundTrue", [True]),
        ("roundFalse", [False]),
    ]
766
Eric Kunzee5e26762020-10-13 16:11:07 -0700767 # Helper function for reshape. Gets some factors of a larger number.
768 @staticmethod
def getFactors(val, start=1):
    """Return the divisors of ``val`` between ``start`` and floor(sqrt(val)).

    Both ends of the range are inclusive; divisors larger than the square
    root are not reported.
    """
    upper = int(np.sqrt(val)) + 1
    return [
        candidate
        for candidate in range(start, upper)
        if (val % candidate) == 0
    ]
777
778 @staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random target shapes for RESHAPE.

    Each permutation draws a new rank via ``testGen.randInt(1, 7)`` and then
    builds a shape by repeatedly picking a random factor of the elements
    still to be distributed; the last dimension takes whatever remains.
    One dimension is occasionally replaced by -1.  A shape already present
    in the result is retried, up to 100 attempts, after which the
    permutation is dropped.

    Returns a list of ``(name, [new_shape])`` tuples.
    """
    max_attempts = 100
    results = []

    element_count = 1
    for dim in shapeList[0]:
        element_count *= dim

    # This code is NOT fast. Fortunately, the numbers are fairly small.
    base_factors = TosaArgGen.getFactors(element_count)

    for perm_idx in range(testGen.args.num_rand_permutations):
        rank = testGen.randInt(1, 7)
        if rank > len(base_factors):
            continue

        for _attempt in range(max_attempts):
            candidate = []
            remaining = element_count
            pool = testGen.rng.permutation(base_factors)

            # Pick rank-1 factors; the final dimension is the remainder.
            for _ in range(1, rank):
                picked = pool[0]
                candidate.append(picked)
                remaining = remaining // picked
                pool = testGen.rng.permutation(TosaArgGen.getFactors(remaining))
            candidate.append(remaining)

            # Toss in a -1 sometimes
            wildcard_at = testGen.randInt(0, rank * 4)
            if wildcard_at < rank:
                candidate[wildcard_at] = -1

            # Only keep shapes that have not been generated before.
            already_seen = False
            for _name, previous in results:
                if previous[0] == candidate:
                    already_seen = True
                    break

            if not already_seen:
                results.append(
                    ("perm{}_rank{}".format(perm_idx, rank), [candidate])
                )
                break

    return results
833
Eric Kunzee5e26762020-10-13 16:11:07 -0700834 @staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
    """Pick random axis permutations for TRANSPOSE.

    All permutations of the input's axes are shuffled with ``testGen.rng``
    and the first ``min(total, testGen.args.num_rand_permutations)`` of them
    are returned as ``("permN", [axis_order])`` tuples.
    """
    axis_count = len(shapeList[0])

    # Get all permutations
    every_order = list(itertools.permutations(range(axis_count)))

    # Limit to possible permutations from shape dimension or argument setting
    wanted = min(len(every_order), testGen.args.num_rand_permutations)

    # Shuffle the complete set, then take the leading entries
    shuffled = testGen.rng.permutation(every_order)

    arg_list = []
    for idx in range(wanted):
        arg_list.append(("perm{}".format(idx), [shuffled[idx].tolist()]))
    return arg_list
855
856 @staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random begin/size pairs for SLICE.

    For every dimension larger than one, a begin index and a size are drawn
    with ``testGen.randInt``; dimensions of size one (or less) always use
    begin 0 and size 1.  A permutation in which any drawn size is zero is
    discarded, but all of its random draws are still made.

    Returns a list of ``("permN", [begin, size])`` tuples.
    """
    arg_list = []
    ifm_shape = shapeList[0]

    for p in range(testGen.args.num_rand_permutations):
        begin = []
        size = []
        usable = True

        for dim in ifm_shape:
            if dim > 1:
                start = testGen.randInt(0, dim)
                extent = testGen.randInt(0, dim - start)
                # Invalid slice size?
                if extent == 0:
                    usable = False
            else:
                start = 0
                extent = 1
            begin.append(start)
            size.append(extent)

        if usable:
            arg_list.append(("perm{}".format(p), [begin, size]))

    return arg_list
884
885 @staticmethod
def agTile(testGen, opName, shapeList, dtype, error_name=None):
    """Generate per-dimension multiples for TILE.

    Multiples are kept small because otherwise this has a tendency to
    generate enormous tensors: a dimension above 1000 gets a multiple of 1,
    any dimension of a shape whose largest extent is above 1000 gets 2, and
    everything else gets ``testGen.randInt(1, 4)``.

    Returns a list of ``("permN", [multiples])`` tuples.
    """
    ifm_shape = shapeList[0]

    def multiple_for(extent):
        """Choose the tiling multiple for one input dimension."""
        if extent > 1000:
            # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
            return 1
        if max(ifm_shape) > 1000:
            return 2
        return testGen.randInt(1, 4)

    arg_list = []
    for p in range(testGen.args.num_rand_permutations):
        multiples = [multiple_for(extent) for extent in ifm_shape]
        arg_list.append(("perm{}".format(p), [multiples]))

    return arg_list
909
910 @staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Build argument permutations for RESIZE.

    For each resize mode the output dtype(s) matching ``dtype`` are chosen;
    then, ``testGen.args.num_rand_permutations`` times, random output
    dimensions are drawn and stride/offset (plus a fixed-point shift for
    non-float outputs) are derived from them.  When ``error_name`` is given
    the values are passed through ``TosaErrorIfArgGen.eiResizeErrorIf``
    before being recorded.

    Returns a list of ``(name, [mode, stride, offset, shift, stride_fp,
    offset_fp, output_dims, dtype, output_dtype])`` tuples.
    """
    arg_list = []
    ifm_shape = shapeList[0]

    def output_dtypes_for(mode):
        """Return the output dtypes to generate for ``mode``, or None to skip."""
        # Exclude illegal {mode, type} configurations. Pick legal output types
        if mode == ResizeMode.NEAREST and dtype == DType.INT8:
            return [DType.INT8]
        if mode == ResizeMode.NEAREST and dtype == DType.INT16:
            return [DType.INT16]
        if mode == ResizeMode.BILINEAR and dtype == DType.INT8:
            return [DType.INT32]
        if mode == ResizeMode.BILINEAR and dtype == DType.INT16:
            return [DType.INT48]
        if dtype == DType.FLOAT:
            return [DType.FLOAT]
        if error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            return [DType.INT8, DType.INT16, DType.INT32]
        return None

    for mode in (ResizeMode.NEAREST, ResizeMode.BILINEAR):
        candidate_dtypes = output_dtypes_for(mode)
        if candidate_dtypes is None:
            continue

        for outputDType in candidate_dtypes:
            float_output = outputDType == DType.FLOAT

            for perm in range(testGen.args.num_rand_permutations):
                # Randomly generate legal output dimensions and shift
                # and then compute the stride and offset based on them.
                # An output_dim of 1 will cause offset to exceed the allowed
                # range, so the minimum value produced below is 2.
                output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                for axis in (0, 1):
                    # Grow the output until input/output ratio is below 16
                    while (
                        float(ifm_shape[axis + 1]) / float(output_dims[axis])
                    ) >= 16:
                        output_dims[axis] += 1

                in_center_h = (ifm_shape[1] - 1) / 2.0
                in_center_w = (ifm_shape[2] - 1) / 2.0
                out_center_h = (output_dims[0] - 1) / 2.0
                out_center_w = (output_dims[1] - 1) / 2.0

                fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                fp_offset_y = in_center_h - fp_stride_y * out_center_h
                fp_offset_x = in_center_w - fp_stride_x * out_center_w

                if float_output:
                    # Float outputs carry stride/offset as floats; the
                    # integer fields are zeroed.
                    shift = 0
                    stride = [0, 0]
                    offset = [0, 0]
                    stride_fp = [fp_stride_y, fp_stride_x]
                    offset_fp = [fp_offset_y, fp_offset_x]
                else:
                    # Start at shift 11 and lower it until the fixed-point
                    # stride/offset values fit inside +/-(16 << shift).
                    shift = 11
                    while True:
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        bound = 16 << shift
                        fits = (
                            stride_y < bound
                            and stride_x < bound
                            and offset_y < bound
                            and offset_x < bound
                            and offset_y > -bound
                            and offset_x > -bound
                        )
                        if fits:
                            break
                        shift = shift - 1

                    stride = [stride_y, stride_x]
                    offset = [offset_y, offset_x]
                    stride_fp = [0.0, 0.0]
                    offset_fp = [0.0, 0.0]

                if error_name is not None:
                    (
                        shift,
                        stride,
                        stride_fp,
                        offset,
                        offset_fp,
                        outputDTypeNew,
                    ) = TosaErrorIfArgGen.eiResizeErrorIf(
                        testGen,
                        error_name,
                        mode,
                        dtype,
                        shapeList,
                        outputDType,
                        shift,
                        stride,
                        stride_fp,
                        offset,
                        offset_fp,
                    )
                else:
                    outputDTypeNew = outputDType

                mode_tag = "N" if mode == ResizeMode.NEAREST else "B"
                if float_output:
                    test_name = "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                        mode_tag,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride_fp[0],
                        stride_fp[1],
                        offset_fp[0],
                        offset_fp[1],
                    )
                else:
                    test_name = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                        mode_tag,
                        shift,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride[0],
                        stride[1],
                        offset[0],
                        offset[1],
                    )

                arg_list.append(
                    (
                        test_name,
                        [
                            mode,
                            stride,
                            offset,
                            shift,
                            stride_fp,
                            offset_fp,
                            output_dims,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                )

    return arg_list
1080
@staticmethod
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Return the condition values to test for COND_IF.

    CondIf generates the condition values here; they are converted to
    tensors in the build function, along with the then and else blocks.
    Declared as a static method, like the other argument generators, so
    that ``testGen`` is never bound to an instance.

    Returns ``[("cond0", [False]), ("cond1", [True])]``.
    """
    return [("cond{}".format(int(c)), [c]) for c in (False, True)]
1091
@staticmethod
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Return the iteration counts to test for WHILE_LOOP.

    Covers zero iterations, one iteration and more than one (four).
    Declared as a static method, like the other argument generators, so
    that ``testGen`` is never bound to an instance.

    Returns ``[("iter0", [0]), ("iter1", [1]), ("iter4", [4])]``.
    """
    # "count" rather than "iter" so the builtin is not shadowed.
    return [("iter{}".format(count), [count]) for count in (0, 1, 4)]
1100
class TosaErrorIfArgGen:
    """Helpers that perturb operator arguments so a chosen ERROR_IF is hit."""

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Adjust RESIZE arguments so that ``error_name`` is triggered.

        Only the argument relevant to ``error_name`` is replaced; the rest
        are returned as passed in.  For ``StrideLargerDimension`` the given
        ``stride`` / ``stride_fp`` sequence is modified in place.

        Returns ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``.

        Raises:
            Exception: ``error_name`` is ``WrongOutputType`` but no set of
                incorrect output types is defined for ``mode`` / ``dtype``.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                # Randomly exceed either the height or the width
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                # NOTE(review): integers(-3, 1) excludes the upper bound by
                # default, so the positive-shift arm looks unreachable; it is
                # kept as a defensive fallback.
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                # Randomly exceed either the height or the width
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            else:
                # Fail with a clear message instead of an UnboundLocalError
                # on incorrect_types below.
                raise Exception(
                    "No incorrect output types defined for mode {} with input dtype {}".format(
                        mode, dtype
                    )
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Return input/output name lists made invalid for ``error_name``.

        For ``"WrongInputList"`` a dummy input is appended or the last input
        is dropped, chosen at random; for ``"WrongOutputList"`` a dummy
        output is appended or the list is emptied.  Any other ``error_name``
        returns both lists unchanged.  The lists passed in are never
        modified; changed lists are returned as new objects.

        Returns ``(input_list, output_list)``.
        """
        # Mess up input/output tensors for ERROR_IF checks
        # NOTE(review): compared against string literals here, whereas the
        # rest of the file uses ErrorIf.* members - confirm they are equal.
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list = input_list + ['eiDummyInput']
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list = output_list + ['eiDummyOutput']
            else:
                output_list = []
        return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001177
1178class TosaErrorValidator:
1179
Matthew Haddon848efb42021-09-09 12:30:53 +01001180 @staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
    """Run every validator and record the expected ERROR_IF on the serializer.

    Each function in ``validator_fcns`` is called as ``fn(True, **kwargs)``
    and must return a dict with ``error_name``, ``error_result`` and
    ``error_reason``.  When the validator for ``error_name`` fires, the
    serializer's expected return code is set to 2 with the validator's
    reason.  The function stops early (after printing a message) when a
    different validator fires, or when the validator for ``error_name`` does
    not fire.

    Always returns None.
    """
    for validator in validator_fcns:
        outcome = validator(True, **kwargs)

        validator_name = outcome['error_name']
        triggered = outcome['error_result']
        reason = outcome['error_reason']

        if triggered and error_name == validator_name:
            serializer.setExpectedReturnCode(2, reason)
            continue

        if triggered:
            print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
            return None  # Return None to delete test if wrong ERROR_IF is hit

        if error_name == validator_name:
            print(f"No ERROR_IF hit for {error_name}")
            return None
1201
1202 @staticmethod
def evWrongInputType(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongInputType ERROR_IF.

    Requires ``kwargs['op']`` with a ``'types'`` entry listing the supported
    input dtypes.  With ``check`` set, ``kwargs['input_dtype']`` is tested
    against that list.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs`` (whose ``dtype`` entry lists the unsupported dtypes).
    """
    every_dtype = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)

    # Find the unsupported input data types
    assert 'op' in kwargs
    supported = kwargs['op']['types']
    unsupported = list(set(every_dtype) - set(supported))

    hit = False
    if check:
        hit = kwargs['input_dtype'] not in supported

    return {
        "error_name": ErrorIf.WrongInputType,
        "error_result": hit,
        "error_reason": "Input data type not supported for this operator",
        "param_reqs": {"rank": None, "dtype": unsupported, "shape": None},
    }
1229
1230 @staticmethod
def evWrongOutputType(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongOutputType ERROR_IF.

    With ``check`` set, reads ``input_dtype``, ``output_dtype`` and ``op``
    from ``kwargs``.  For RESIZE (which also needs ``kwargs['mode']``) the
    output dtype must match the one paired with the mode and input dtype
    below; for every other operator it must equal the input dtype.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        op = kwargs['op']

        if op['op'] == Op.RESIZE:
            mode = kwargs['mode']
            nearest = mode == ResizeMode.NEAREST
            bilinear = mode == ResizeMode.BILINEAR

            # Each arm pairs a (mode, input dtype) with its required output
            if nearest and input_dtype == DType.INT8:
                hit = bool(output_dtype != DType.INT8)
            elif nearest and input_dtype == DType.INT16:
                hit = bool(output_dtype != DType.INT16)
            elif bilinear and input_dtype == DType.INT8:
                hit = bool(output_dtype != DType.INT32)
            elif bilinear and input_dtype == DType.INT16:
                hit = bool(output_dtype != DType.INT48)
            elif input_dtype == DType.FLOAT:
                hit = bool(output_dtype != DType.FLOAT)
        elif output_dtype != input_dtype:
            hit = True

    return {
        "error_name": ErrorIf.WrongOutputType,
        "error_result": hit,
        "error_reason": "Output data type not supported for this configuration of operator",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1263
1264 @staticmethod
def evWrongRank(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongRank ERROR_IF.

    Requires ``kwargs['op']`` with a ``'rank'`` (min, max) pair.  The
    unsupported ranks among 1..5 are reported in ``param_reqs``; for RESIZE
    they are fixed to [3, 5].  With ``check`` set, the error is only flagged
    for RESIZE, when ``kwargs['input_shape'].shape`` is not rank 4.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    candidate_ranks = (1, 2, 3, 4, 5)

    # Make a list of incorrect ranks
    assert 'op' in kwargs
    op = kwargs['op']
    lowest, highest = op['rank']
    supported = range(lowest, highest + 1)
    bad_ranks = list(set(candidate_ranks) - set(supported))
    is_resize = op['op'] == Op.RESIZE
    if is_resize:
        # Set minimum incorrect rank to 3 to avoid index error
        bad_ranks = [3, 5]

    hit = False
    if check:
        input_shape = kwargs['input_shape']
        if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
            hit = True

    return {
        "error_name": ErrorIf.WrongRank,
        "error_result": hit,
        "error_reason": "Rank not supported for this operator",
        "param_reqs": {"rank": bad_ranks, "dtype": None, "shape": None},
    }
1295
1296 @staticmethod
def evWrongInputList(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongInputList ERROR_IF.

    With ``check`` set, the error is flagged when the length of
    ``kwargs['input_list']`` differs from ``kwargs['num_operands']``.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        # 'op' is fetched but not used below; the lookup is kept so that a
        # missing key still raises KeyError.
        kwargs['op']
        supplied = kwargs['input_list']
        expected_count = kwargs['num_operands']
        hit = len(supplied) != expected_count

    return {
        "error_name": ErrorIf.WrongInputList,
        "error_result": hit,
        "error_reason": "Op input list does not match expected input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1317
1318 @staticmethod
def evWrongOutputList(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongOutputList ERROR_IF.

    With ``check`` set, the error is flagged when ``kwargs['output_list']``
    does not contain exactly one entry.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        # Note this will be incorrect if an operator returns more than one output
        hit = len(kwargs['output_list']) != 1

    return {
        "error_name": ErrorIf.WrongOutputList,
        "error_result": hit,
        "error_reason": "Op output list does not match expected output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
Matthew Haddone86fd342021-09-07 16:12:21 +01001338
1339 @staticmethod
def evMaxDimExceeded(check=False, **kwargs):
    """Describe, and optionally evaluate, the MaxDimExceeded ERROR_IF.

    With ``check`` set, the error is flagged when index 1 or 2 of
    ``kwargs['input_shape'].shape``, or index 0 or 1 of
    ``kwargs['output_shape']``, is larger than 16384.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    max_dim = 16384
    requirements = {
        "rank": [4, 4],
        "dtype": [DType.INT8],
        "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
    }

    hit = False
    if check:
        in_shape = kwargs['input_shape'].shape
        out_hw = kwargs['output_shape']  # Note this is just (OH, OW)
        hit = bool(
            in_shape[1] > max_dim
            or in_shape[2] > max_dim
            or out_hw[0] > max_dim
            or out_hw[1] > max_dim
        )

    return {
        "error_name": ErrorIf.MaxDimExceeded,
        "error_result": hit,
        "error_reason": "At least one maximum dimension is larger than 16384",
        "param_reqs": requirements,
    }
1366
1367 @staticmethod
def evBatchMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the BatchMismatch ERROR_IF.

    Requires ``kwargs['op']`` with a ``'rank'`` (min, max) pair.  With
    ``check`` set, the error is flagged when the input rank is within that
    range and index 0 of ``kwargs['input_shape'].shape`` differs from index
    0 of ``kwargs['result_tensor'].shape``.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    assert 'op' in kwargs
    lowest, highest = kwargs['op']['rank']
    supported_ranks = range(lowest, highest + 1)

    hit = False
    if check:
        in_shape = kwargs['input_shape'].shape
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
        hit = bool(
            len(in_shape) in supported_ranks and in_shape[0] != out_shape[0]
        )

    return {
        "error_name": ErrorIf.BatchMismatch,
        "error_result": hit,
        "error_reason": "Input batch size not equal to output batch size",
        "param_reqs": {"rank": [4, 4], "dtype": None, "shape": None},
    }
1393
1394 @staticmethod
def evChannelMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the ChannelMismatch ERROR_IF.

    Requires ``kwargs['op']`` with a ``'rank'`` (min, max) pair.  With
    ``check`` set, the error is flagged when the input rank is within that
    range and index 3 of ``kwargs['input_shape'].shape`` differs from index
    3 of ``kwargs['result_tensor'].shape``.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    assert 'op' in kwargs
    lowest, highest = kwargs['op']['rank']
    supported_ranks = range(lowest, highest + 1)

    hit = False
    if check:
        in_shape = kwargs['input_shape'].shape
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
        hit = bool(
            len(in_shape) in supported_ranks and in_shape[3] != out_shape[3]
        )

    return {
        "error_name": ErrorIf.ChannelMismatch,
        "error_result": hit,
        "error_reason": "Input channel size not equal to output channel size",
        "param_reqs": {"rank": [4, 4], "dtype": None, "shape": None},
    }
1419
1420 @staticmethod
def evStrideSmallerEqualZero(check=False, **kwargs):
    """Describe, and optionally evaluate, the StrideSmallerEqualZero ERROR_IF.

    With ``check`` set, either ``kwargs['stride_fp']`` or ``kwargs['stride']``
    is selected from the input/output dtypes and the error is flagged when
    its smallest element is zero or negative.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']

        if output_dtype == DType.FLOAT:
            # A non-float input with a float output uses the integer
            # stride: work around wrong input/output type tests
            stride_key = 'stride' if input_dtype != DType.FLOAT else 'stride_fp'
        else:
            # A float input with a non-float output uses the float
            # stride: work around wrong input/output type tests
            stride_key = 'stride_fp' if input_dtype == DType.FLOAT else 'stride'
        stride = kwargs[stride_key]

        if min(stride) <= 0:
            hit = True

    return {
        "error_name": ErrorIf.StrideSmallerEqualZero,
        "error_result": hit,
        "error_reason": "Stride value smaller than or equal zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1449
1450 @staticmethod
def evStrideLargerEqualMax(check=False, **kwargs):
    """Describe, and optionally evaluate, the StrideLargerEqualMax ERROR_IF.

    With ``check`` set and an INT8/INT16 input dtype, the error is flagged
    when either element of ``kwargs['stride']`` reaches ``16 << shift``
    (or ``16 >> -shift`` for a negative ``kwargs['shift']``).

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride']
        if input_dtype in [DType.INT8, DType.INT16]:
            # A negative shift scales the limit down rather than up
            limit = (16 << shift) if shift >= 0 else (16 >> -shift)
            hit = bool(stride[0] >= limit or stride[1] >= limit)

    return {
        "error_name": ErrorIf.StrideLargerEqualMax,
        "error_result": hit,
        "error_reason": "Stride value larger than or equal to maximum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1474
1475
1476 @staticmethod
def evStrideLargerDimension(check=False, **kwargs):
    """Describe, and optionally evaluate, the StrideLargerDimension ERROR_IF.

    With ``check`` set, the error is flagged only for FLOAT input when
    ``kwargs['stride_fp'][0]`` exceeds index 1 of
    ``kwargs['input_shape'].shape`` or ``stride_fp[1]`` exceeds index 2.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    error_name = ErrorIf.StrideLargerDimension
    param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    error_result = False
    error_reason = "Stride value larger than or equal to H/W dimension"

    if check:
        shape = kwargs['input_shape'].shape
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride_fp']

        # The dtype guard must cover BOTH dimension checks.  Without the
        # parentheses "A and B or C" parses as "(A and B) or C", so the
        # second check would also fire for non-float inputs.
        if input_dtype == DType.FLOAT and (
            (stride[0] > shape[1]) or (stride[1] > shape[2])
        ):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1498
1499
1500 @staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
    """Describe, and optionally evaluate, the OffsetSmallerEqualMin ERROR_IF.

    With ``check`` set, ``kwargs['offset_fp']`` (FLOAT output) or
    ``kwargs['offset']`` is tested; the error is flagged when either element
    is at or below ``-16 << shift`` (or ``-16 >> -shift`` for a negative
    ``kwargs['shift']``).

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        offset = kwargs['offset_fp'] if output_dtype == DType.FLOAT else kwargs['offset']

        # A negative shift scales the limit down rather than up
        floor_value = (-16 << shift) if shift >= 0 else (-16 >> -shift)
        hit = bool(offset[0] <= floor_value or offset[1] <= floor_value)

    return {
        "error_name": ErrorIf.OffsetSmallerEqualMin,
        "error_result": hit,
        "error_reason": "Offset value smaller than or equal to minimum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1527
1528 @staticmethod
def evOffsetLargerEqualMax(check=False, **kwargs):
    """Describe, and optionally evaluate, the OffsetLargerEqualMax ERROR_IF.

    With ``check`` set, ``kwargs['offset_fp']`` (FLOAT output) or
    ``kwargs['offset']`` is tested; the error is flagged when either element
    reaches ``16 << shift`` (or ``16 >> -shift`` for a negative
    ``kwargs['shift']``).

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    hit = False

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        offset = kwargs['offset_fp'] if output_dtype == DType.FLOAT else kwargs['offset']

        # A negative shift scales the limit down rather than up
        ceiling = (16 << shift) if shift >= 0 else (16 >> -shift)
        hit = bool(offset[0] >= ceiling or offset[1] >= ceiling)

    return {
        "error_name": ErrorIf.OffsetLargerEqualMax,
        "error_result": hit,
        "error_reason": "Offset value larger than or equal to maximum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1559
@staticmethod
def evShiftNotZero(check=False, **kwargs):
    """Describe, and optionally evaluate, the ShiftNotZero condition.

    With check=True, kwargs must provide 'shift', 'input_dtype' and
    'output_dtype'. The condition holds when both dtypes are DType.FLOAT
    and shift is not zero.

    Returns a dict with the keys error_name, error_result, error_reason and
    param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.ShiftNotZero,
        "error_result": False,
        "error_reason": "Shift value must be zero for float input",
        "param_reqs": {"rank": None, "dtype": [DType.FLOAT], "shape": None},
    }

    if check:
        shift = kwargs['shift']
        in_dtype = kwargs['input_dtype']
        out_dtype = kwargs['output_dtype']
        if in_dtype == DType.FLOAT and out_dtype == DType.FLOAT:
            if shift != 0:
                info_dict["error_result"] = True

    return info_dict
1581
1582
@staticmethod
def evShiftSmallerOne(check=False, **kwargs):
    """Describe, and optionally evaluate, the ShiftSmallerOne condition.

    With check=True, kwargs must provide 'shift', 'input_dtype' and
    'output_dtype'. The condition holds when shift is below one and neither
    dtype is DType.FLOAT.

    Returns a dict with the keys error_name, error_result, error_reason and
    param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.ShiftSmallerOne,
        "error_result": False,
        "error_reason": "Shift value smaller than one",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }

    if check:
        shift = kwargs['shift']
        in_dtype = kwargs['input_dtype']
        out_dtype = kwargs['output_dtype']
        if shift < 1:
            if in_dtype != DType.FLOAT and out_dtype != DType.FLOAT:
                info_dict["error_result"] = True

    return info_dict
1604
@staticmethod
def evShiftLargerEleven(check=False, **kwargs):
    """Describe, and optionally evaluate, the ShiftLargerEleven condition.

    With check=True, kwargs must provide 'shift'; the condition holds when
    it is greater than eleven.

    Returns a dict with the keys error_name, error_result, error_reason and
    param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.ShiftLargerEleven,
        "error_result": False,
        "error_reason": "Shift value larger than eleven",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }

    if check:
        if kwargs['shift'] > 11:
            info_dict["error_result"] = True

    return info_dict
1624
1625
@staticmethod
def evRankMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the RankMismatch condition.

    With check=True, kwargs must provide 'input1', 'input2' and
    'result_tensor', each exposing a .shape. The condition holds when the
    length of either input shape differs from the length of the output
    shape.

    Returns a dict with the keys error_name, error_result, error_reason and
    param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.RankMismatch,
        "error_result": False,
        "error_reason": "Input Rank does not match output rank",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }

    if check:
        input1_shape = kwargs['input1'].shape
        input2_shape = kwargs['input2'].shape
        output_shape = kwargs['result_tensor'].shape
        input_ranks = (len(input1_shape), len(input2_shape))
        if any(rank != len(output_shape) for rank in input_ranks):
            info_dict["error_result"] = True

    return info_dict
1647
1648
class TosaInvalidValidator:
    """Checks that report argument combinations which are invalid.

    Every static method takes keyword arguments and returns True when the
    combination it inspects is invalid, False otherwise.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True for an unknown resize mode or unsupported dtype pair.

        Reads kwargs["input_dtype"] and kwargs["args"], where args[0] is the
        resize mode and args[8] is the output dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Invalid unless the (input, output) dtypes form one of the
            # accepted pairs. The accepted pairs are mutually exclusive, so
            # the whole disjunction is negated rather than each pair.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Output dtype must equal the input dtype, and the input must be
            # one of the dtypes accepted in nearest mode.
            return (
                (input_dtype != output_dtype)
                or (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True when a stride component is zero or negative.

        Reads kwargs["input_dtype"] and kwargs["args"]. For DType.FLOAT
        input the pair args[4] is checked; for any other input dtype the
        pair args[1] is checked.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x, stride_y = args[1][0], args[1][1]
        stride_fp_x, stride_fp_y = args[4][0], args[4][1]

        if input_dtype == DType.FLOAT:
            # Negative or zero floating-point stride
            return bool(stride_fp_x <= 0 or stride_fp_y <= 0)
        # Negative or zero integer stride
        return bool(stride_x <= 0 or stride_y <= 0)

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True when the computed output height or width is <= 0.

        Reads kwargs['opName'], kwargs['shapeList'] and kwargs['args'].
        args[0] holds the strides and args[1] the four padding values;
        args[2] is the kernel for ops whose name ends with "pool2d" and the
        dilations otherwise. shapeList[0] is the input shape and, for
        non-pooling ops, shapeList[1] is the filter shape.

        Raises:
            AssertionError: opName is neither a conv2d, a depthwise_conv2d
                nor a pool2d variant.
        """
        opName = kwargs['opName']
        is_pool = opName.endswith("pool2d")

        shapes = kwargs['shapeList']
        ifm_shape = shapes[0]
        if not is_pool:
            filter_shape = shapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if is_pool:
            kernel = args[2]

        def conv_extent(in_size, k_size, dilation, pad_a, pad_b, stride):
            # Output size along one axis of a dilated, strided convolution.
            return (
                in_size - k_size - (k_size - 1) * (dilation - 1) + pad_a + pad_b
            ) // stride + 1

        if opName.startswith('conv2d'):
            # Filter height/width are read from indices 1 and 2
            h = conv_extent(ifm_shape[1], filter_shape[1], dilations[0],
                            padding[0], padding[1], strides[0])
            w = conv_extent(ifm_shape[2], filter_shape[2], dilations[1],
                            padding[2], padding[3], strides[1])
        elif opName.startswith("depthwise_conv2d"):
            # Filter height/width are read from indices 0 and 1
            h = conv_extent(ifm_shape[1], filter_shape[0], dilations[0],
                            padding[0], padding[1], strides[0])
            w = conv_extent(ifm_shape[2], filter_shape[1], dilations[1],
                            padding[2], padding[3], strides[1])
        elif is_pool:
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            # Raised explicitly so the check survives "python -O"
            raise AssertionError("Unrecognized Op")

        # Invalid parameter combination when either extent is non-positive
        return bool(h <= 0 or w <= 0)

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True when output_shape[1] or output_shape[2] is <= 0.

        The output shape is read from kwargs['args'][3].
        """
        output_shape = kwargs['args'][3]
        # Zero or negative output dimension
        return bool(output_shape[1] <= 0 or output_shape[2] <= 0)
1765
1766
Kevin Cheng550ccc52021-03-03 11:21:43 -08001767
Eric Kunzee5e26762020-10-13 16:11:07 -07001768class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001769 # Maximum rank of tensor supported by test generator.
1770 TOSA_TENSOR_MAX_RANK = 6
1771
def __init__(self, args):
    """Initialise generator state from the parsed arguments.

    Reads args.output_dir and args.random_seed, seeds the random number
    generator and runs the op-list setup methods.
    """
    self.args = args
    # Base directory below which createSerializer() creates test directories
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer exists until createSerializer() is called
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
1783
def createSerializer(self, opName, testPath):
    """Create a test's output directory and attach a new serializer to it.

    The relative path opName/testPath is stored in self.testPath; the
    directory below self.basePath is created if missing and a
    ts.TosaSerializer for that directory is stored in self.ser.
    """
    relative_dir = os.path.join(opName, testPath)
    self.testPath = relative_dir

    target_dir = os.path.join(self.basePath, relative_dir)
    os.makedirs(target_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(target_dir)
1790
def getSerializer(self):
    """Return the current serializer stored in self.ser."""
    serializer = self.ser
    return serializer
1793
def serialize(self, testName):
    """Write the serialized test and its JSON description to disk.

    Both files go into basePath/testPath: "<testName>.tosa" receives the
    bytes from self.ser.serialize() and "desc.json" the text from
    self.ser.writeJson().
    """
    out_dir = os.path.join(self.basePath, self.testPath)
    tosa_filename = "{}.tosa".format(testName)

    with open(os.path.join(out_dir, tosa_filename), "wb") as tosa_fd:
        tosa_fd.write(self.ser.serialize())

    with open(os.path.join(out_dir, "desc.json"), "w") as desc_fd:
        desc_fd.write(self.ser.writeJson(tosa_filename))
Eric Kunzee5e26762020-10-13 16:11:07 -07001802
def resetRNG(self, seed=None):
    """Replace self.rng with a freshly seeded generator.

    Args:
        seed: Seed for the new generator. When None, self.random_seed + 1
            is used.
    """
    # Identity test: "== None" would misbehave for array-like seeds
    if seed is None:
        seed = self.random_seed + 1
    self.rng = np.random.default_rng(seed)
1807
def getRandTensor(self, shape, dtype):
    """Return a random numpy array of the given shape for a TOSA dtype.

    BOOL yields np.bool_ values; INT4, INT8, UINT8, INT16 and INT32 yield
    np.int32 values drawn from the dtype's range; INT48 yields np.int64
    values; FLOAT yields np.float32 values from self.rng.random().

    Raises:
        Exception: dtype is not one of the handled DType values.
    """
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        return np.int32(self.rng.integers(low=-7, high=8, size=shape))
    elif dtype == DType.INT8:
        return np.int32(self.rng.integers(low=-128, high=128, size=shape))
    elif dtype == DType.UINT8:
        return np.int32(self.rng.integers(low=0, high=256, size=shape))
    elif dtype == DType.INT16:
        return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
    elif dtype == DType.INT32:
        return np.int32(
            self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
        )
    elif dtype == DType.INT48:
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001833
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one randomly filled placeholder per (shape, dtype) pair.

    shape_list and dtype_list must have the same length. Returns the values
    returned by self.ser.addPlaceholder, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))

    return placeholders
1844
def buildConstTensors(self, shape_list, dtype_list):
    """Add one randomly filled constant per (shape, dtype) pair.

    shape_list and dtype_list must have the same length. Returns the values
    returned by self.ser.addConst, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))

    return consts
1855
def makeShape(self, rank):
    """Return a shape as an np.int32 array.

    When self.targetted_shape is set (truthy) it is returned, converted to
    np.int32, regardless of rank. Otherwise `rank` dimensions are drawn from
    self.rng within self.args.tensor_shape_range (low inclusive, high
    exclusive).
    """
    forced_shape = self.targetted_shape
    if forced_shape:
        return np.int32(forced_shape)

    shape_range = self.args.tensor_shape_range
    dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -07001866
def setTargetShape(self, shape):
    """Store `shape` in self.targetted_shape, which makeShape() returns while set."""
    setattr(self, "targetted_shape", shape)
1869
def randInt(self, low=0, high=256):
    """Return one np.int32 drawn from self.rng in [low, high)."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    return np.int32(drawn)[0]
1872
def getRandNumberDType(self, dtype):
    """Return a single random value for the given TOSA dtype.

    FLOAT returns self.rng.random(); BOOL returns a choice of False/True;
    INT4, INT8, INT16 and INT32 return an np.int32 from the dtype's range;
    INT48 returns an np.int64.

    Raises:
        Exception: dtype is not one of the handled DType values.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # (dtype, (low, high)) with high exclusive; results fit in np.int32
    int32_ranges = (
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, (-7, 8)),
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, (-(1 << 31), (1 << 31))),
    )
    for candidate, (low, high) in int32_ranges:
        if dtype == candidate:
            return np.int32(self.rng.integers(low, high, size=1))[0]

    if dtype == DType.INT48:
        # Special size
        low, high = (-(1 << 47), (1 << 47))
        return np.int64(self.rng.integers(low, high, size=1))[0]

    raise Exception("Unknown dtype: {}".format(dtype))
1895
def shapeStr(self, shape):
    """Return the dimensions of `shape` joined by "x", e.g. "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07001904
def typeStr(self, t):
    """Return the short string name of a dtype.

    A list must hold at least two dtypes; its first two entries are rendered
    as "<first>x<second>".

    Raises:
        Exception: t is not one of the handled DType values.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))

    labels = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for candidate, label in labels:
        if t == candidate:
            return label

    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001928
def typeWidth(self, t):
    """Get the datatype width, in bits, for integer types.

    Raises:
        Exception: t is not one of the handled integer DType values.
    """
    widths = (
        (DType.INT4, 4),
        (DType.INT8, 8),
        (DType.UINT8, 8),
        (DType.INT16, 16),
        (DType.INT32, 32),
        (DType.INT48, 48),
    )
    for candidate, bits in widths:
        if t == candidate:
            return bits

    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001945
1946 # Argument generators
1947 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1948 # Where the string descriptor is used to generate the test name and
1949 # The build_fcn_arg_list is expanded and passed to the operator test
1950 # build function
1951
def build_unary(self, op, a, qinfo=None):
    """Add a single-input operator and return its output tensor.

    `op` is either an int opcode or a dict with the opcode under 'op'. The
    output tensor comes from OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)
    # build_placeholder returns an int, ABS/other ops does not
    opcode = op if isinstance(op, int) else op['op']
    self.ser.addOperator(opcode, [a.name], [out_tens.name], None, qinfo)
    return out_tens
1960
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a two-input broadcasting operator and return its output tensor.

    The output tensor comes from OutputShaper.binaryBroadcastOp. The
    input/output name lists are passed through
    TosaErrorIfArgGen.eiInvalidateInputOutputList and the ERROR_IF
    validators are run before the operator is added.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name, b.name]
    result_names = [out_tens.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    validator_kwargs = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_dtype": a.dtype,
        "output_dtype": out_tens.dtype,
        "result_tensor": out_tens,
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": operand_total,
    }
    TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    )

    self.ser.addOperator(op['op'], input_list, output_list)
    return out_tens
1989
def build_binary_nonbroadcast(self, op, a, b):
    """Add a two-input operator shaped by OutputShaper.binaryNonBroadcastOp.

    Returns the output tensor.
    """
    out_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
1994
def build_arithmetic_right_shift(self, op, a, b, round):
    """Add a two-input operator carrying an ArithmeticRightShiftAttribute.

    `round` is forwarded to the attribute. Returns the output tensor from
    OutputShaper.binaryBroadcastOp.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], shift_attr)
    return out_tens
2003
def build_mul(self, op, a, b, shift):
    """Add a two-input operator carrying a MulAttribute(shift).

    Returns the output tensor from OutputShaper.binaryBroadcastOp, with its
    dtype set to INT32 when the first input is not FLOAT.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

    # Special for multiply:
    # Force the result to INT32 for INT types
    is_float = a.dtype == DType.FLOAT
    if not is_float:
        out_tens.setDtype(DType.INT32)

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], mul_attr)
    return out_tens
2017
def build_table(self, op, a):
    """Add a table operator with a randomly filled constant table.

    The table has 513 INT16 entries for INT16 input and 256 INT8 entries
    for INT8 input; any other input dtype fails the assertion. Returns the
    output tensor from OutputShaper.tableOp.
    """
    # Constant size depending on type, random values
    if a.dtype == DType.INT16:
        table_dtype, table_len = DType.INT16, 513
    else:
        assert a.dtype == DType.INT8
        table_dtype, table_len = DType.INT8, 256
    table_data = self.getRandTensor([table_len], table_dtype)

    table_tens = self.ser.addConst(table_data.shape, table_dtype, table_data)
    out_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
    self.ser.addOperator(
        op['op'], [a.name, table_tens.name], [out_tens.name], None
    )

    return out_tens
2033
def build_select(self, op, cond, a, b):
    """Add a three-input operator (cond, a, b) shaped by OutputShaper.selectOp.

    Returns the output tensor.
    """
    out_tens = OutputShaper.selectOp(self.ser, cond, a, b)
    operand_names = [cond.name, a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2038
def build_comparison(self, op, a, b):
    """Add a two-input operator shaped by OutputShaper.binaryComparisonOp.

    Returns the output tensor.
    """
    out_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2043
def build_argmax(self, op, a, axis):
    """Add a single-input operator carrying an AxisAttribute(axis).

    Returns the output tensor from OutputShaper.argmaxOp.
    """
    out_tens = OutputShaper.argmaxOp(self.ser, a, axis)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], axis_attr)
    return out_tens
2052
def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
    """Add a pooling operator carrying a PoolAttribute(kernel, stride, pad).

    Returns the output tensor from OutputShaper.pool2dOp.
    """
    out_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad)

    self.ser.addOperator(
        op['op'], [input.name], [out_tens.name], pool_attr, qinfo
    )
    return out_tens
2061
def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Add a convolution operator with inputs (ifm, filter, bias).

    `padding` must have four entries. The operator carries a
    ConvAttribute(padding, strides, dilations). Returns the output tensor
    from OutputShaper.conv2dOp.
    """
    assert len(padding) == 4
    out_tens = OutputShaper.conv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], conv_attr, qinfo)
    return out_tens
2075
def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Add a 3D convolution operator with inputs (ifm, filter, bias).

    `padding` must have six entries. The operator carries a
    ConvAttribute(padding, strides, dilations). Returns the output tensor
    from OutputShaper.conv3dOp.
    """
    assert len(padding) == 6
    out_tens = OutputShaper.conv3dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], conv_attr, qinfo)
    return out_tens
2089
def build_transpose_conv2d(
    self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
):
    """Add a transpose convolution operator with inputs (ifm, filter, bias).

    `outpad` must have two entries. The operator carries a
    TransposeConvAttribute(outpad, stride, dilation, output_shape). Returns
    the output tensor from OutputShaper.transposeConv2DOp.
    """
    assert len(outpad) == 2
    out_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], tconv_attr, qinfo)
    return out_tens
2103
def build_depthwise_conv2d(
    self, op, ifm, filter, bias, strides, padding, dilations, qinfo
):
    """Add a depthwise convolution operator with inputs (ifm, filter, bias).

    The operator carries a ConvAttribute(padding, strides, dilations).
    Returns the output tensor from OutputShaper.depthwiseConv2dOp.
    """
    out_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], conv_attr, qinfo)
    return out_tens
2118
def build_fully_connected(self, op, ifm, filter, bias, qinfo):
    """Add an operator with inputs (ifm, filter, bias) and no attribute.

    Returns the output tensor from OutputShaper.fullyConnectedOp.
    """
    out_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], None, qinfo)
    return out_tens
2126
def build_matmul(self, op, a, b, qinfo):
    """Add a two-input operator shaped by OutputShaper.matmulOp.

    Returns the output tensor.
    """
    out_tens = OutputShaper.matmulOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], None, qinfo)
    return out_tens
2131
def build_reduce(self, op, a, axis):
    """Add a single-input operator carrying an AxisAttribute(axis).

    Returns the output tensor from OutputShaper.reduceOp.
    """
    result_tens = OutputShaper.reduceOp(self.ser, a, axis)

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    # Output names are passed as a list, as in every other build_* method
    self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
    return result_tens
2140
def build_clamp(self, op, a):
    """Add a clamp operator with random bounds and return its output tensor.

    Two random values of the input dtype are drawn; the smaller becomes the
    minimum and the larger the maximum. FLOAT input fills the last two
    ClampAttribute arguments and zeroes the first two; other dtypes do the
    reverse.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)

    clamp_attr = ts.TosaSerializerAttribute()
    first = self.getRandNumberDType(a.dtype)
    second = self.getRandNumberDType(a.dtype)
    bounds = [first, second]
    lower, upper = min(bounds), max(bounds)

    if a.dtype == DType.FLOAT:
        clamp_attr.ClampAttribute(0, 0, lower, upper)
    else:
        clamp_attr.ClampAttribute(lower, upper, 0, 0)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], clamp_attr)
    return out_tens
2154
def build_leaky_relu(self, op, a):
    """Add a single-input operator carrying a LeakyReluAttribute.

    The attribute value is a random FLOAT from getRandNumberDType. Returns
    the output tensor from OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)

    relu_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FLOAT)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], relu_attr)
    return out_tens
2163
# Needs an additional type/input
def build_prelu(self, op, a):
    """Add a single-input operator with no attribute.

    Returns the output tensor from OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)
    operand_names = [a.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2170
def build_sigmoid(self, op, a):
    """Add a single-input operator with no attribute.

    Returns the output tensor from OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)
    operand_names = [a.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2175
def build_tanh(self, op, a):
    """Add a single-input operator with no attribute.

    Returns the output tensor from OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)
    operand_names = [a.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2180
def build_concat(self, op, *a):
    """Add a concat operator over a variable number of input tensors.

    The positional arguments are the input tensors followed by the axis,
    which must be an int. The operator carries an AxisAttribute(axis).
    Returns the output tensor from OutputShaper.concatOp.
    """
    assert type(a[-1]) == int

    # To store variable length list of input tensors we need to store axis along with it
    *tensors, axis = a
    tensors = tuple(tensors)

    out_tens = OutputShaper.concatOp(self.ser, axis, *tensors)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    operand_names = [tensor.name for tensor in tensors]

    self.ser.addOperator(op['op'], operand_names, [out_tens.name], axis_attr)
    return out_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002199
def build_pad(self, op, a, padding, qinfo):
    """Add a pad operator whose padding is supplied as an INT32 constant.

    Returns the output tensor from OutputShaper.padOp.
    """
    out_tens = OutputShaper.padOp(self.ser, a, padding)

    # Need to turn the padding array into a TOSA tensor here.
    # This is one of the few tensor operands that does not get
    # randomly generated
    pad_const = self.ser.addConst(padding.shape, DType.INT32, padding)

    operand_names = [a.name, pad_const.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], None, qinfo)
    return out_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002212
def build_reshape(self, op, a, newShape):
    """Add a single-input operator carrying a ReshapeAttribute(newShape).

    Returns the output tensor from OutputShaper.reshapeOp.
    """
    out_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], reshape_attr)
    return out_tens
2221
def build_reverse(self, op, a, axis):
    """Add a single-input operator carrying an AxisAttribute(axis).

    Returns the output tensor from OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, a)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], axis_attr)
    return out_tens
2230
def build_transpose(self, op, a, perms):
    """Add a transpose operator whose permutation is an INT32 constant.

    Returns the output tensor from OutputShaper.transposeOp.
    """
    out_tens = OutputShaper.transposeOp(self.ser, a, perms)

    perms_const = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

    operand_names = [a.name, perms_const.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2238
def build_slice(self, op, a, begin, size):
    """Add a single-input operator carrying a SliceAttribute(begin, size).

    Returns the output tensor from OutputShaper.sliceOp.
    """
    out_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(begin, size)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], slice_attr)
    return out_tens
2247
def build_tile(self, op, a, multiples):
    """Add a single-input operator carrying a TileAttribute(multiples).

    Returns the output tensor from OutputShaper.tileOp.
    """
    out_tens = OutputShaper.tileOp(self.ser, a, multiples)

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], tile_attr)
    return out_tens
2256
def build_gather(self, op, values):
    """Add a gather operator with a randomly generated INT32 index constant.

    The index array has shape (values.shape[0], W), where W is drawn from
    self.args.tensor_shape_range, and every index lies in
    [0, values.shape[1]). Returns the output tensor from
    OutputShaper.gatherOp.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values tensor
    K = values.shape[1]  # K
    shape_range = self.args.tensor_shape_range
    W = self.randInt(shape_range[0], shape_range[1])  # W
    index_shape = [values.shape[0], W]  # (N, W)
    index_data = np.int32(self.rng.integers(low=0, high=K, size=index_shape))
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.gatherOp(self.ser, values, index_tens)

    operand_names = [values.name, index_tens.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])

    return out_tens
2276
def build_scatter(self, op, values_in, input):
    """Add a scatter operator with a randomly generated INT32 index constant.

    The index array has shape (values_in.shape[0], input.shape[1]) and every
    index lies in [0, values_in.shape[1]). Returns the output tensor from
    OutputShaper.scatterOp.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values_in tensor
    K = values_in.shape[1]  # K
    W = input.shape[1]  # W
    index_shape = [values_in.shape[0], W]  # (N, W)
    index_data = np.int32(self.rng.integers(low=0, high=K, size=index_shape))
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.scatterOp(self.ser, values_in, index_tens, input)

    operand_names = [values_in.name, index_tens.name, input.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])

    return out_tens
2296
Matthew Haddon848efb42021-09-09 12:30:53 +01002297
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a resize operator, running the ERROR_IF validators first.

    Args:
        op: operator entry; ``op['op']`` and ``op['operands']`` are read.
        input: input tensor.
        mode, stride, offset, shift, stride_fp, offset_fp, output_dims:
            resize parameters, forwarded to the output shaper, the
            validators and the resize attribute.
        input_dtype, output_dtype: data types forwarded to the output
            shaper and the validators.
        validator_fcns: validators handed to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: name of the ERROR_IF condition under test, or ``None``.

    Returns:
        The output tensor created by ``OutputShaper.resizeOp``.
    """
    output = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholder_count, const_count = op["operands"]

    # The input/output name lists may be altered for ERROR_IF checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [output.name]
    )

    validation_kwargs = dict(
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=input_list,
        output_list=output_list,
        result_tensor=output,
        num_operands=placeholder_count + const_count,
    )
    TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    )

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )

    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return output
2366
def build_identityn(self, op, val, val2):
    """Serialize a two-input, two-output identity operator.

    Args:
        op: operator entry; its ``'op'`` value is handed to the serializer,
            the same way the sibling ``build_*`` methods do.
        val: first input tensor.
        val2: second input tensor.

    Returns:
        The output tensor created for ``val``. The second output tensor is
        registered with the serializer but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, val)
    result_tens2 = OutputShaper.unaryOp(self.ser, val2)
    # Pass the opcode stored in the entry, not the entry itself.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
2375
def build_const(self, op, val):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op`` is accepted for signature parity with the other builders and is
    not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002379
2380 # Type Conversion
def build_cast(self, op, val, out_dtype):
    """Serialize a cast of ``val`` to ``out_dtype``.

    Returns:
        The output tensor created by ``OutputShaper.typeConversionOp``.
    """
    output = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
    input_names = [val.name]
    output_names = [output.name]
    self.ser.addOperator(op["op"], input_names, output_names)
    return output
2385
def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
    """Serialize a rescale operator with randomly chosen scale and zero points.

    Args:
        op: operator entry; its ``'op'`` value is handed to the serializer.
        val: input tensor.
        out_dtype: requested output data type.
        scale32: selects the upper clamp applied to the random scale and is
            forwarded to ``TosaQuantGen.computeMultiplierAndShift``.
        double_round: forwarded to the rescale attribute.
        per_channel: when true, one scale is generated per element of the
            last dimension of ``val``; otherwise a single scale is used.

    Returns:
        The output tensor created by ``OutputShaper.typeConversionOp``.
    """
    output = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

    num_channels = val.shape[-1] if per_channel else 1

    in_width = self.typeWidth(val.dtype)
    out_width = self.typeWidth(out_dtype)

    def pick_zero_point(dtype):
        # 8-bit types get a random zero point and one extra bit of width.
        if dtype == DType.INT8:
            return self.randInt(-128, 128), 1
        if dtype == DType.UINT8:
            return self.randInt(0, 256), 1
        return 0, 0

    # Input zero point is drawn before the output one.
    input_zp, in_extra = pick_zero_point(val.dtype)
    in_width = in_width + in_extra
    output_zp, out_extra = pick_zero_point(out_dtype)
    out_width = out_width + out_extra

    # scale = a * (2^output_width) / (2^input_width), with random a
    rand_factor = np.float32(self.rng.random(size=[num_channels]))
    scales = rand_factor * np.float32((1 << out_width) / (1 << in_width))

    # Cap the scaling at 2^31 - 1 for scale32, 2^15 - 1 for scale16
    upper_bound = (1 << 31) - 1 if scale32 else 32767.0
    scales = np.clip(scales, 1.0 / (1 << 31), upper_bound)

    multipliers = np.int32(np.zeros(shape=[num_channels]))
    shifts = np.int32(np.zeros(shape=[num_channels]))

    for channel in range(num_channels):
        multipliers[channel], shifts[channel] = TosaQuantGen.computeMultiplierAndShift(
            scales[channel], scale32
        )

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp,
        output_zp,
        multipliers,
        shifts,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], [val.name], [output.name], rescale_attr)
    return output
2454
def build_cond_if_const(self, op, then_tens, else_tens, cond):
    """Serialize a cond_if whose THEN/ELSE blocks each hold one constant.

    The supplied then/else tensors are ignored except for the shape of
    ``then_tens``; fresh random INT32 constants of that shape are created
    inside the two blocks instead.

    Args:
        op: operator entry; its ``'op'`` value is handed to the serializer.
        then_tens: tensor whose shape is reused for both branches.
        else_tens: unused on entry.
        cond: value stored in the boolean condition constant.

    Returns:
        The INT32 output tensor of the cond_if operator.
    """
    # Condition tensor
    condition = self.ser.addConst([], DType.BOOL, [cond])

    # Random payloads for the two branches, drawn before any block is opened
    branch_shape = then_tens.shape
    then_data = np.int32(self.rng.integers(0, 256, size=branch_shape))
    else_data = np.int32(self.rng.integers(0, 256, size=branch_shape))

    # Result tensor of the operator itself
    output = self.ser.addOutput(branch_shape, DType.INT32)

    # Attribute naming the then/else blocks
    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    # Emit the operator, then the two blocks
    self.ser.addOperator(op["op"], [condition.name], [output.name], cond_attr)

    for block_name, payload in ((then_block, then_data), (else_block, else_data)):
        self.ser.startBasicBlock(block_name)
        # The constant is built inside its block and becomes the block output
        branch_const = self.ser.addConst(branch_shape, DType.INT32, payload)
        self.ser.addOutputTensor(branch_const)

    return output
2490
def build_cond_if_binary(self, op, a, b, cond):
    """Serialize a cond_if whose branches add or subtract ``a`` and ``b``.

    The THEN block computes ``a + b`` and the ELSE block computes ``a - b``.

    Args:
        op: operator entry; its ``'op'`` value is handed to the serializer.
        a: first operand; its shape and dtype are used for all outputs.
        b: second operand.
        cond: value stored in the boolean condition constant.

    Returns:
        The output tensor of the cond_if operator.
    """
    # Condition tensor
    condition = self.ser.addConst([], DType.BOOL, [cond])

    output = self.ser.addOutput(a.shape, a.dtype)

    # Attribute naming the then/else blocks
    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    # Emit the operator, then the two blocks
    self.ser.addOperator(
        op["op"], [condition.name, a.name, b.name], [output.name], cond_attr
    )

    for block_name, branch_op in ((then_block, Op.ADD), (else_block, Op.SUB)):
        self.ser.startBasicBlock(block_name)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        branch_out = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(branch_op, [a.name, b.name], [branch_out.name])

    return output
2524
def build_while_loop(self, op, a, iter_val):
    """Serialize a while_loop that accumulates ``a`` ``iter_val`` times.

    The COND block tests ``counter > 0``. The BODY block adds ``a`` to the
    accumulator and subtracts one from the counter.

    Args:
        op: operator entry; its ``'op'`` value is handed to the serializer.
        a: tensor added to the accumulator on each iteration.
        iter_val: initial value of the loop counter.

    Returns:
        The intermediate tensor holding the accumulator output of the loop,
        which is also registered as an output tensor.
    """
    counter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"

    loop_attr = ts.TosaSerializerAttribute()
    loop_attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, zero-initialised with the shape of ``a``
    acc_start = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_start)

    # Intermediate/output tensors for everything going through the loop
    counter_out = self.ser.addIntermediate(counter.shape, counter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    loop_inputs = (counter, a, acc)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [tensor.name for tensor in loop_inputs],
        [counter_out.name, a_out.name, acc_out.name],
        loop_attr,
    )
    self.ser.addOutputTensor(acc_out)

    # COND block (input: counter, a, acc; output: boolean condition)
    self.ser.startBasicBlock(cond_block)
    for tensor in loop_inputs:
        self.ser.addInputTensor(tensor)
    zero = self.ser.addConst([], DType.INT32, [np.int32(0)])
    keep_going = self.ser.addOutput([], DType.BOOL)
    self.ser.addOperator(Op.GREATER, [counter.name, zero.name], [keep_going.name])

    # BODY block (input: counter, a, acc; output: counter, a, acc)
    # Local intermediate tensors are declared here for the outputs
    self.ser.startBasicBlock(body_block)
    for tensor in loop_inputs:
        self.ser.addInputTensor(tensor)
    one = self.ser.addConst([], DType.INT32, [np.int32(1)])
    counter_next = self.ser.addIntermediate(counter.shape, counter.dtype)
    acc_next = self.ser.addIntermediate(acc.shape, acc.dtype)
    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_next.name])
    self.ser.addOperator(Op.SUB, [counter.name, one.name], [counter_next.name])
    for tensor in (counter_next, a, acc_next):
        self.ser.addOutputTensor(tensor)

    return acc_out
2578
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
    """Compute the shape, rank and dtype filters used to enumerate tests.

    Args:
        op: operator entry; ``op["rank"]`` is a ``(min, max)`` pair and
            ``op["types"]`` is the collection of supported dtypes.
        shapeFilter: requested shapes; a falsy value is treated as
            ``[None]`` (no shape restriction).
        rankFilter: requested ranks or ``None``. Ranks outside the
            operator's range are dropped.
        dtypeFilter: requested dtypes or ``None``. Dtypes the operator does
            not list are dropped.
        testType: ``'positive'`` or ``'negative'``.
        validator: for negative tests, a callable invoked as
            ``validator(check=False, op=op)`` whose result carries a
            ``'param_reqs'`` dict with ``'rank'``, ``'dtype'`` and
            ``'shape'`` entries. A non-``None`` entry overrides the
            corresponding filter.

    Returns:
        A dict with ``'shapeFilter'``, ``'rankFilter'`` and
        ``'dtypeFilter'`` keys, or ``None`` when ``testType`` is neither
        ``'positive'`` nor ``'negative'``.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Default behaviour is bounded by the default range or by the
        # operator, whichever is smaller.
        lowest = min(default_test_rank_range)
        highest = max(default_test_rank_range)
        cleanRankFilter = [
            rank for rank in range(rmin, rmax + 1) if lowest <= rank <= highest
        ]
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep only the requested dtypes that the operator allows
        cleanDtypeFilter = [dtype for dtype in dtypeFilter if dtype in dtypes]
    else:
        cleanDtypeFilter = dtypes

    if testType == 'positive':
        return {
            'shapeFilter': shapeFilter,
            'rankFilter': cleanRankFilter,
            'dtypeFilter': cleanDtypeFilter,
        }

    if testType == 'negative':
        validator_info = validator(check=False, op=op)
        error_arguments = validator_info['param_reqs']

        # Identity checks against None: an equality test would be evaluated
        # element-wise for array-like requirements.
        if error_arguments['rank'] is not None:
            rankFilter = error_arguments['rank']
        else:
            rankFilter = cleanRankFilter

        if error_arguments['dtype'] is not None:
            dtypeFilter = error_arguments['dtype']
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments['shape'] is not None:
            shapeFilter = error_arguments['shape']
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            'shapeFilter': shapeFilter,
            'rankFilter': rankFilter,
            'dtypeFilter': dtypeFilter,
        }

    return None
2646
2647
def genOpTestList(
    self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: requested shapes (the default list is only read, never
            mutated here).
        rankFilter: requested ranks or ``None``.
        dtypeFilter: requested dtypes or ``None``.
        testType: ``'positive'`` or ``'negative'``. Negative tests are
            generated once per entry of the operator's
            ``"error_if_validators"`` list.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

    testList = []
    if testType == 'negative' and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)['error_name']
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        cleanRankFilter = filterDict['rankFilter']
        cleanDtypeFilter = filterDict['dtypeFilter']
        cleanShapeFilter = filterDict['shapeFilter']

        for r in cleanRankFilter:
            if opName.startswith("conv3d"):
                assert r == 5, "conv3d test must have input rank == 5"
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == 'positive':
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == 'negative':
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append((opName, testStr, t, error_name, shapeList, args))

    if testType == 'positive' and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped when ANY validator flags it;
        # the verdict must not be reset between validators.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2737
Matthew Haddone86fd342021-09-07 16:12:21 +01002738
def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
    """Build and serialize a single test for ``opName``.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name handed to ``self.createSerializer``.
        dtype_or_dtypeList: one dtype replicated for every operand, or an
            explicit per-operand list.
        error_name: name of the ERROR_IF condition under test, or ``None``.
        shapeList: per-operand shapes.
        testArgs: extra positional arguments for the operator's build function.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the build function after printing its inputs.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
    error_if_validators = (
        op["error_if_validators"] if "error_if_validators" in op else None
    )

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif is_concat:
        # CONCAT takes a variable number of inputs, one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op["qgen"] if "qgen" in op else None

    # Build the random tensor operands and the test
    tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

    qinfo = qgen(self, op, dtype_or_dtypeList) if qgen is not None else None

    try:
        # Positional layout: tensors, test args, then optional quantization
        # info, then optional (validators, error_name) pair.
        build_args = [*tens, *testArgs]
        if qinfo is not None:
            build_args.append(qinfo)
        if error_if_validators is not None:
            build_args.extend([error_if_validators, error_name])
        resultName = build_fcn(self, op, *build_args)
    except TypeError as e:
        print(
            "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                build_fcn, tens, testArgs
            )
        )
        raise e

    if resultName is None:
        print("Invalid ERROR_IF tests created")

    # Save the serialized test
    self.serialize("test")
2815
2816
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
    """Create the input tensors for one test of ``op``.

    Most operators get plain random placeholders followed by random
    constants. A few operators get special handling so that the generated
    values stay valid for the operation (see the branches below).

    Args:
        op: operator entry; ``op["op"]`` selects the branch and
            ``op["operands"]`` gives ``(placeholder_count, const_count)``.
        dtypeList: per-operand dtypes. NOTE: the SELECT branch overwrites
            ``dtypeList[0]`` in place.
        shapeList: per-operand shapes.
        testArgs: operator arguments; ``testArgs[0]`` is read as the shift
            for integer MUL and passed to ``tgConcatConstInput`` for CONCAT.
        error_name: ERROR_IF condition under test, or ``None``. The ADD/SUB
            and INTDIV value constraints only apply when it is ``None``.

    Returns:
        List of the tensors returned by the serializer, in operand order.
    """
    pCount, cCount = op["operands"]

    tens = []
    if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
        # Make sure the operation does not cause value saturation - where
        # the number wraps due to limited number of bits to store the answer
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
        placeholders = []
        add = (op["op"] == Op.ADD)
        a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
        b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
        # Compute the exact result in 64 bits so overflow can be detected
        if add:
            res_arr = np.add(a_arr, b_arr, dtype=np.int64)
        else:
            res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

        # Work out the saturation limits
        max_i32 = (1 << 31)-1
        min_i32 = -(1 << 31)
        max_arr = np.full(shapeList[1], max_i32)
        min_arr = np.full(shapeList[1], min_i32)

        # Find how much values exceed the maximum/minimums
        sat_max_arr = np.maximum(res_arr - max_arr, 0)
        sat_min_arr = np.minimum(res_arr - min_arr, 0)

        if not add:
            # Swap saturation values and negate values as we need to perform opposite operations
            sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

        # Create new array of unsaturated values by clipping values as needed
        # (only the second operand is adjusted; a_arr is used unchanged)
        b_unsat_arr = b_arr
        if (sat_max_arr != 0).any():
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

        if (sat_min_arr != 0).any():
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

        placeholders.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
        )
        placeholders.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
        )

        tens.extend(placeholders)
    elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
        # Force value of operand[1] to be within [0, num_bits]
        # NOTE(review): the draws below pass high=num_bits; numpy's
        # Generator.integers excludes `high` by default, so the shift
        # amounts appear to be in [0, num_bits - 1] - confirm intent.
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = self.getRandTensor(shape, dtypeList[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

        tens.extend(placeholders)
    elif op["op"] == Op.SELECT:
        # Set datatype of condition tensor to boolean
        # (this mutates the caller's dtypeList)
        dtypeList[0] = DType.BOOL
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
    elif op["op"] == Op.INTDIV and error_name == None:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.INTDIV must have 2 placeholders, 0 consts"

        placeholders = []

        # Two invalid cases for Op.INTDIV:
        # 1. divisor == 0
        # 2. dividend == -(1<<31) and divisor == -1
        # Redraw both operands until neither case can occur.
        while True:
            dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

            if (divisor_arr == 0).any():
                continue

            # Rejects whenever both values appear anywhere in the arrays,
            # not only at matching positions.
            if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                continue

            break

        placeholders.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
        )
        placeholders.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
        )

        tens.extend(placeholders)
    elif op["op"] == Op.MUL:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.MUL must have 2 placeholders, 0 consts"

        if dtypeList[0] == DType.FLOAT:
            tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
        else:
            placeholders = []

            # Make sure multiply result in int32 range
            shift = testArgs[0]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            else:
                raise Exception("OpMul: invalid input dtype")

            # NOTE(review): this loop redraws both operands once per entry
            # of shapeList and only the last pair survives. The extra draws
            # still advance self.rng, so removing the loop would change the
            # generated data.
            for idx, shape in enumerate(shapeList[:]):
                low = -(2 ** (num_bits - 1))
                high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[1])
                )

            # Halve both operands until the (optionally shifted and rounded)
            # 64-bit product fits the accepted range. `i` counts the
            # halvings but is not read afterwards.
            i = 0
            while True:

                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                # The lower bound is strict, so a result of exactly
                # -(2 ** 31) is also rejected.
                if (result_arr > -(2 ** 31)).all() and (
                    result_arr <= ((2 ** 31) - 1)
                ).all():
                    break

                i = i + 1
                a_arr = a_arr // 2
                b_arr = b_arr // 2

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )

            tens.extend(placeholders)
    elif op["op"] == Op.CONCAT:
        # The first `count` inputs become placeholders, the rest constants;
        # at least one placeholder is always kept.
        count = len(shapeList) - self.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if self.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
    else:
        # Default: pCount random placeholders followed by random constants
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

    return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003017
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST in place.

    Each ``<op>_TEMPLATE`` entry is shallow-copied once per kernel size into
    a concrete ``<op>_<k0>x<k1>[x<k2>]`` entry whose "filter" is that kernel
    and whose "template" flag is False.  Afterwards every entry still
    flagged as a template is deleted.

    The expansion is re-entrant: TOSA_OP_LIST is mutated in place and the
    templates are deleted at the end, so a later call finds the concrete
    entries but no template.  That case is now skipped instead of raising
    KeyError.

    Raises:
        KeyError: if a template is missing and its concrete entry has not
            been generated either.
    """
    op_list = self.TOSA_OP_LIST

    # Dynamically create op lists for convolutions with a list of kernel sizes
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    OPS_2D = ("conv2d", "depthwise_conv2d", "transpose_conv2d")

    def expand(prefix, kernel):
        # Instantiate one concrete op from "<prefix>_TEMPLATE" for `kernel`.
        test_name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        template_name = prefix + "_TEMPLATE"
        if template_name not in op_list and test_name in op_list:
            # Already expanded by an earlier call; nothing left to do.
            return
        entry = op_list[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[test_name] = entry

    # Keep the per-kernel interleaving (conv2d, depthwise, transpose) so the
    # insertion order of the generated entries is unchanged.
    for kernel in KERNELS_2D:
        for prefix in OPS_2D:
            expand(prefix, kernel)

    for kernel in KERNELS_3D:
        expand("conv3d", kernel)

    # Delete any templates after having created any dynamic ops.
    # Collect the names first: keys must not be deleted while iterating.
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
3064
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    For each op, in this order: "operands" must unpack into two values,
    "build_fcn" must unpack into three, and "types" and "op" must be
    present.  The first violation raises Exception with a message naming
    the op.  An entry that has no "rank" then receives DEFAULT_RANK_RANGE.
    """
    # (field, number of values it must unpack into, error text)
    arity_rules = (
        ("operands", 2, "Op {} is missing a valid operand tuple in TOSA_OP_LIST"),
        ("build_fcn", 3, "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST"),
    )
    # (field that only has to exist, error text)
    presence_rules = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for op_name, spec in self.TOSA_OP_LIST.items():

        # Required fields
        for field, arity, complaint in arity_rules:
            try:
                # Same acceptance rule as tuple-unpacking into `arity` names.
                if len(tuple(spec[field])) != arity:
                    raise ValueError(field)
            except (KeyError, ValueError, TypeError):
                raise Exception(complaint.format(op_name))

        for field, complaint in presence_rules:
            try:
                spec[field]
            except KeyError:
                raise Exception(complaint.format(op_name))

        # Put in default rank range, if missing
        if "rank" not in spec:
            spec["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003106
# Tensor operator list
#  'op': Op enum value of the operator
#  'operands': tuple of (placeholder, const) operand counts
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#          if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#               ArgGen function or None)
#  'types': array of datatypes to be tested

# Single-family building blocks
TYPE_FP = [DType.FLOAT]
TYPE_BOOL = [DType.BOOL]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4

# Combinations; each one is its own list object
TYPE_INT_FP = TYPE_INT + TYPE_FP  # Excludes INT4
TYPE_FIB = TYPE_FP + TYPE_INT + TYPE_BOOL
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Convolution-style ops: each integer entry is a list of three dtypes,
# the float case is a bare dtype.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# Rank range given to ops that do not specify "rank" themselves
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003134
# Table of every operator the generator knows about, keyed by test name.
# Entries whose name ends in "_TEMPLATE" carry "template": True; they are
# expanded per kernel size and then removed by createDynamicOpLists().
# Besides the required "op", "operands", "build_fcn" and "types" fields, an
# entry may carry "rank", "qgen", "filter", "invalid_test_validators" and
# "error_if_validators".
TOSA_OP_LIST = {
    # Tensor operators
    "argmax": {
        "op": Op.ARGMAX,
        "operands": (1, 0),
        "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_NARROW_INT_FP,
    },
    "avg_pool2d": {
        "op": Op.AVG_POOL2D,
        "operands": (1, 0),
        "rank": (4, 4),
        "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
        "qgen": TosaQuantGen.qgUnary,
        "types": TYPE_NARROW_INT_FP,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
    },
    # Templated operator.  Filled in by createDynamicOpLists
    "conv2d_TEMPLATE": {
        "op": Op.CONV2D,
        "operands": (1, 2),
        "rank": (4, 4),
        "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        "template": True,
    },
    # Templated operator.  Filled in by createDynamicOpLists
    "conv3d_TEMPLATE": {
        "op": Op.CONV3D,
        "operands": (1, 2),
        "rank": (5, 5),
        "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "template": True,
    },
    # Templated operator.  Filled in by createDynamicOpLists
    "depthwise_conv2d_TEMPLATE": {
        "op": Op.DEPTHWISE_CONV2D,
        "operands": (1, 2),
        "filter": [1, 1],
        "rank": (4, 4),
        "build_fcn": (
            build_depthwise_conv2d,
            TosaTensorGen.tgDepthwiseConv2D,
            TosaArgGen.agConv,
        ),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        "template": True,
    },
    "fully_connected": {
        "op": Op.FULLY_CONNECTED,
        "operands": (1, 2),
        "rank": (2, 2),
        "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
    },
    "matmul": {
        "op": Op.MATMUL,
        "operands": (2, 0),
        "rank": (3, 3),
        "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
        "qgen": TosaQuantGen.qgMatmul,
        "types": TYPE_NARROW_INT_FP,
    },
    "max_pool2d": {
        "op": Op.MAX_POOL2D,
        "operands": (1, 0),
        "rank": (4, 4),
        "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
        "types": TYPE_NARROW_INT_FP,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
    },
    # Templated operator.  Filled in by createDynamicOpLists
    "transpose_conv2d_TEMPLATE": {
        "op": Op.TRANSPOSE_CONV2D,
        "operands": (1, 2),
        "rank": (4, 4),
        "build_fcn": (
            build_transpose_conv2d,
            TosaTensorGen.tgTransposeConv2D,
            TosaArgGen.agTransposeConv2D,
        ),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
        "template": True,
    },
    # Activation functions
    "clamp": {
        "op": Op.CLAMP,
        "operands": (1, 0),
        "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
        "types": TYPE_NARROW_INT_FP,
    },
    "sigmoid": {
        "op": Op.SIGMOID,
        "operands": (1, 0),
        "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "tanh": {
        "op": Op.TANH,
        "operands": (1, 0),
        "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    # Elementwise Binary Operators
    "add": {
        "op": Op.ADD,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "arithmetic_right_shift": {
        "op": Op.ARITHMETIC_RIGHT_SHIFT,
        "operands": (2, 0),
        "build_fcn": (
            build_arithmetic_right_shift,
            TosaTensorGen.tgBroadcastFuzz,
            TosaArgGen.agArithmeticRightShift,
        ),
        "types": TYPE_INT,
    },
    "bitwise_and": {
        "op": Op.BITWISE_AND,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "bitwise_or": {
        "op": Op.BITWISE_OR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "bitwise_xor": {
        "op": Op.BITWISE_XOR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "intdiv": {
        "op": Op.INTDIV,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": [DType.INT32],
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_and": {
        "op": Op.LOGICAL_AND,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_left_shift": {
        "op": Op.LOGICAL_LEFT_SHIFT,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_right_shift": {
        "op": Op.LOGICAL_RIGHT_SHIFT,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_or": {
        "op": Op.LOGICAL_OR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_xor": {
        "op": Op.LOGICAL_XOR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "maximum": {
        "op": Op.MAXIMUM,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "minimum": {
        "op": Op.MINIMUM,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "mul": {
        "op": Op.MUL,
        "operands": (2, 0),
        "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
        "types": TYPE_INT_FP,
    },
    "pow": {
        "op": Op.POW,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "sub": {
        "op": Op.SUB,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "table": {
        "op": Op.TABLE,
        # Use the automatic generation functions to create the input array
        # but create the table tensor in the build function, as it may be
        # a different type from the input
        "operands": (1, 0),
        "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
        "types": [DType.INT8, DType.INT16],
    },
    # Elementwise Unary operators
    "abs": {
        "op": Op.ABS,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FI32,
    },
    "bitwise_not": {
        "op": Op.BITWISE_NOT,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_INT,
    },
    "ceil": {
        "op": Op.CEIL,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "clz": {
        "op": Op.CLZ,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": [DType.INT32],
    },
    "exp": {
        "op": Op.EXP,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "floor": {
        "op": Op.FLOOR,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "log": {
        "op": Op.LOG,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "logical_not": {
        "op": Op.LOGICAL_NOT,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_BOOL,
    },
    "negate": {
        "op": Op.NEGATE,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "qgen": TosaQuantGen.qgUnary,
        "types": TYPE_INT_FP,
    },
    "reciprocal": {
        "op": Op.RECIPROCAL,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "rsqrt": {
        "op": Op.RSQRT,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    # Elementwise Ternary operators
    "select": {
        "op": Op.SELECT,
        "operands": (3, 0),
        "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FIB,
    },
    # Comparison operators
    "equal": {
        "op": Op.EQUAL,
        "operands": (2, 0),
        "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
    },
    "greater_equal": {
        "op": Op.GREATER_EQUAL,
        "operands": (2, 0),
        "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
    },
    "greater": {
        "op": Op.GREATER,
        "operands": (2, 0),
        "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
    },
    # Reduction operators
    "reduce_all": {
        "op": Op.REDUCE_ALL,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_BOOL,
    },
    "reduce_any": {
        "op": Op.REDUCE_ANY,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_BOOL,
    },
    "reduce_max": {
        "op": Op.REDUCE_MAX,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_INT_FP,
    },
    "reduce_min": {
        "op": Op.REDUCE_MIN,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_INT_FP,
    },
    "reduce_product": {
        "op": Op.REDUCE_PRODUCT,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_FP,
    },
    "reduce_sum": {
        "op": Op.REDUCE_SUM,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_FI32,
    },
    # Data layout operators
    "concat": {
        "op": Op.CONCAT,
        "operands": (2, 0),
        "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
        "types": TYPE_FIB,
    },
    "pad": {
        "op": Op.PAD,
        "operands": (1, 0),
        "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
        "qgen": TosaQuantGen.qgPad,
        "types": TYPE_FIB,
    },
    "reshape": {
        "op": Op.RESHAPE,
        "operands": (1, 0),
        "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
        "types": TYPE_FIB,
    },
    "reverse": {
        "op": Op.REVERSE,
        "operands": (1, 0),
        "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_FIB,
    },
    "slice": {
        "op": Op.SLICE,
        "operands": (1, 0),
        "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
        "types": TYPE_FIB,
    },
    "tile": {
        "op": Op.TILE,
        "operands": (1, 0),
        "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
        "types": TYPE_FIB,
    },
    "transpose": {
        "op": Op.TRANSPOSE,
        "operands": (1, 0),
        "rank": (1, 4),
        "build_fcn": (
            build_transpose,
            TosaTensorGen.tgBasic,
            TosaArgGen.agTranspose,
        ),
        "types": TYPE_FIB,
    },
    # Data nodes
    "const": {
        "op": Op.CONST,
        "operands": (0, 1),
        "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
        "types": TYPE_FIB,
    },
    "identity": {
        "op": Op.IDENTITY,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FIB,
    },
    # Scatter/Gather
    "gather": {
        "op": Op.GATHER,
        # Only specify 'values' tensor here. 'indices' is generated in op building stage
        "operands": (1, 0),
        "rank": (3, 3),
        "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
        "types": TYPE_INT_FP,
    },
    "scatter": {
        "op": Op.SCATTER,
        # Only specify 'values_in' tensor here.
        # 'indices' and 'input' are generated in op building stage
        "operands": (2, 0),
        "rank": (3, 3),
        "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
        "types": TYPE_INT_FP,
    },
    # Image operations
    "resize": {
        "op": Op.RESIZE,
        "operands": (1, 0),
        "rank": (4, 4),
        "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
        "types": [DType.INT8, DType.INT16, DType.FLOAT],
        "invalid_test_validators": (
            TosaInvalidValidator.ivWrongDataTypeOrModeResize,
            TosaInvalidValidator.ivBadStride,
        ),
        "error_if_validators": (
            TosaErrorValidator.evMaxDimExceeded,
            TosaErrorValidator.evStrideSmallerEqualZero,
            TosaErrorValidator.evStrideLargerDimension,
            TosaErrorValidator.evStrideLargerEqualMax,
            TosaErrorValidator.evOffsetSmallerEqualMin,
            TosaErrorValidator.evOffsetLargerEqualMax,
            TosaErrorValidator.evShiftNotZero,
            TosaErrorValidator.evShiftSmallerOne,
            TosaErrorValidator.evShiftLargerEleven,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
            TosaErrorValidator.evBatchMismatch,
            TosaErrorValidator.evChannelMismatch,
        ),
    },
    # Type conversion
    "cast": {
        "op": Op.CAST,
        "operands": (1, 0),
        "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
        "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
    },
    "rescale": {
        "op": Op.RESCALE,
        "operands": (1, 0),
        "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
        "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
    },
    # Custom
    # Not implemented.
    # Control flow operators
    # Two variants of cond_if, one that generates one of two constant tensors (no
    # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
    # (two inputs to the basic blocks, one output)
    "cond_if_const": {
        "op": Op.COND_IF,
        "operands": (0, 2),
        "build_fcn": (
            build_cond_if_const,
            TosaTensorGen.tgBasic,
            TosaArgGen.agCondIf,
        ),
        "types": [DType.BOOL],
    },
    "cond_if_binary": {
        "op": Op.COND_IF,
        "operands": (2, 0),
        "build_fcn": (
            build_cond_if_binary,
            TosaTensorGen.tgBasic,
            TosaArgGen.agCondIf,
        ),
        "types": TYPE_FI32,
    },
    # while_loop
    "while_loop": {
        "op": Op.WHILE_LOOP,
        "operands": (0, 1),
        "build_fcn": (
            build_while_loop,
            TosaTensorGen.tgBasic,
            TosaArgGen.agWhileLoop,
        ),
        "types": [DType.INT32],
    },
}
3661
Kevin Cheng550ccc52021-03-03 11:21:43 -08003662
Eric Kunzee5e26762020-10-13 16:11:07 -07003663class OutputShaper:
3664 # Methods in this class compute the expected output shape and datatype
3665 # for common classes of operations
def __init__(self):
    """Do nothing; the constructor sets no instance state."""
    pass
3668
3669 # These methods return arguments that can be used for
3670 # creating a new output tensor
3671 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting binary op on ``a`` and ``b``.

    The inputs must share a dtype, and must share a rank unless
    ``error_name`` is ErrorIf.RankMismatch.  With no error requested, every
    size-1 dimension of ``a`` takes the matching size from ``b``; for any
    error test the output shape is simply ``a``'s shape.  For
    ErrorIf.WrongOutputType the output dtype is drawn with ``rng.choice``
    from the candidate dtypes other than ``a.dtype``.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcasting is only applied to valid (non error) tests.
    out_shape = [
        b.shape[axis] if (dim == 1 and error_name is None) else dim
        for axis, dim in enumerate(a.shape)
    ]

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        # Set difference kept so the order handed to rng.choice is unchanged.
        mismatched = list(set(candidates) - {a.dtype})
        out_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003692
3693 @staticmethod
3694 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003695 assert len(a.shape) == len(b.shape)
3696 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003697
3698 shape = []
3699 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003700 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003701 shape.append(a.shape[i])
3702
Kevin Cheng550ccc52021-03-03 11:21:43 -08003703 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003704
3705 @staticmethod
3706 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003707 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003708
3709 @staticmethod
3710 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003711 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
3712 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003713
3714 shape = []
3715 for i in range(len(a.shape)):
3716 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3717
Kevin Cheng550ccc52021-03-03 11:21:43 -08003718 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003719
3720 @staticmethod
3721 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003722 assert len(a.shape) == len(b.shape)
3723 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003724
3725 # Do broadcast
3726 shape = []
3727 for i in range(len(a.shape)):
3728 if a.shape[i] == 1:
3729 shape.append(b.shape[i])
3730 else:
3731 shape.append(a.shape[i])
3732
3733 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08003734 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07003735
3736 @staticmethod
3737 def reduceOp(ser, a, axis):
3738
3739 shape = a.shape.copy()
3740
3741 shape[axis] = 1
3742
Kevin Cheng550ccc52021-03-03 11:21:43 -08003743 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003744
3745 @staticmethod
3746 def argmaxOp(ser, a, axis):
3747 shape = a.shape.copy()
3748 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003749 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07003750
3751 @staticmethod
3752 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
3753
3754 # IFM: NHWC
3755 # Filter: OHWI
3756 # OFM: NHWC
3757
3758 if len(padding) == 2:
3759 # Expand padding to 4 parameters in the case of transpose_conv2d
3760 # From H,W to T,B,L,R
3761 padding = [padding[0], padding[0], padding[1], padding[1]]
3762
Kevin Cheng550ccc52021-03-03 11:21:43 -08003763 h = (
3764 ifm.shape[1]
3765 - filter.shape[1]
3766 - (filter.shape[1] - 1) * (dilations[0] - 1)
3767 + padding[0]
3768 + padding[1]
3769 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003770
Kevin Cheng550ccc52021-03-03 11:21:43 -08003771 w = (
3772 ifm.shape[2]
3773 - filter.shape[2]
3774 - (filter.shape[2] - 1) * (dilations[1] - 1)
3775 + padding[2]
3776 + padding[3]
3777 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003778
Eric Kunzee5e26762020-10-13 16:11:07 -07003779 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
3780
Kevin Cheng3a478572021-01-22 17:21:02 -08003781 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003782 out_dtype = DType.INT32
3783 elif ifm.dtype == DType.INT16:
3784 out_dtype = DType.INT48
3785 elif ifm.dtype == DType.FLOAT:
3786 out_dtype = DType.FLOAT
3787 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003788 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003789
Kevin Cheng550ccc52021-03-03 11:21:43 -08003790 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003791
3792 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07003793 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
3794
3795 # IFM: NDHWC
3796 # Filter: ODHWI
3797 # OFM: NDHWC
3798
3799 d = (
3800 ifm.shape[1]
3801 - filter.shape[1]
3802 - (filter.shape[1] - 1) * (dilations[0] - 1)
3803 + padding[0]
3804 + padding[1]
3805 ) // strides[0] + 1
3806
3807 h = (
3808 ifm.shape[2]
3809 - filter.shape[2]
3810 - (filter.shape[2] - 1) * (dilations[1] - 1)
3811 + padding[2]
3812 + padding[3]
3813 ) // strides[1] + 1
3814
3815 w = (
3816 ifm.shape[3]
3817 - filter.shape[3]
3818 - (filter.shape[3] - 1) * (dilations[2] - 1)
3819 + padding[4]
3820 + padding[5]
3821 ) // strides[2] + 1
3822
3823 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
3824
3825 if ifm.dtype == DType.INT8:
3826 out_dtype = DType.INT32
3827 elif ifm.dtype == DType.INT16:
3828 out_dtype = DType.INT48
3829 elif ifm.dtype == DType.FLOAT:
3830 out_dtype = DType.FLOAT
3831 else:
3832 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
3833
3834 return ser.addOutput(ofm_shape, out_dtype)
3835
3836 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07003837 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
3838 # IFM: NHWC
3839 # Filter: HWCM
3840 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08003841 h = (
3842 ifm.shape[1]
3843 - filter.shape[0]
3844 - (filter.shape[0] - 1) * (dilations[0] - 1)
3845 + padding[0]
3846 + padding[1]
3847 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003848
Kevin Cheng550ccc52021-03-03 11:21:43 -08003849 w = (
3850 ifm.shape[2]
3851 - filter.shape[1]
3852 - (filter.shape[1] - 1) * (dilations[1] - 1)
3853 + padding[2]
3854 + padding[3]
3855 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003856
Eric Kunzee5e26762020-10-13 16:11:07 -07003857 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
3858
Kevin Cheng3a478572021-01-22 17:21:02 -08003859 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003860 out_dtype = DType.INT32
3861 elif ifm.dtype == DType.INT16:
3862 out_dtype = DType.INT48
3863 elif ifm.dtype == DType.FLOAT:
3864 out_dtype = DType.FLOAT
3865 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003866 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003867
Kevin Cheng550ccc52021-03-03 11:21:43 -08003868 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003869
3870 @staticmethod
3871 def pool2dOp(ser, ifm, kernel, stride, pad):
3872 # input: NHWC
3873 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
3874 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
3875
Eric Kunzee5e26762020-10-13 16:11:07 -07003876 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003877 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003878
3879 @staticmethod
3880 def fullyConnectedOp(ser, input, filter):
3881 # input: N, IC
3882 # filter: OC, IC
3883 # output: N, OC
3884
3885 output_shape = [input.shape[0], filter.shape[0]]
3886
Kevin Cheng3a478572021-01-22 17:21:02 -08003887 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003888 out_dtype = DType.INT32
3889 elif input.dtype == DType.INT16:
3890 out_dtype = DType.INT48
3891 elif input.dtype == DType.FLOAT:
3892 out_dtype = DType.FLOAT
3893 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003894 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003895
Kevin Cheng550ccc52021-03-03 11:21:43 -08003896 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003897
3898 @staticmethod
3899 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07003900 # a: N, H, C
3901 # b: N, C, W
3902 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07003903
Kevin Cheng2d60f002021-06-09 14:18:32 -07003904 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003905
Kevin Cheng3a478572021-01-22 17:21:02 -08003906 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003907 out_dtype = DType.INT32
3908 elif a.dtype == DType.INT16:
3909 out_dtype = DType.INT48
3910 elif a.dtype == DType.FLOAT:
3911 out_dtype = DType.FLOAT
3912 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003913 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003914
Kevin Cheng550ccc52021-03-03 11:21:43 -08003915 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003916
3917 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01003918 def concatOp(ser, axis, *a):
3919 input1 = a[0]
3920 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07003921
Matthew Haddon818ab902021-07-27 09:12:49 +01003922 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07003923
Matthew Haddon818ab902021-07-27 09:12:49 +01003924 output_shape[axis] = input1.shape[axis]
3925
3926 for tensor in remaining_inputs:
3927 output_shape[axis] += tensor.shape[axis]
3928
3929 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003930
3931 @staticmethod
3932 def padOp(ser, a, padding):
3933
3934 output_shape = a.shape.copy()
3935
3936 for i in range(len(output_shape)):
3937 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
3938
Kevin Cheng550ccc52021-03-03 11:21:43 -08003939 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003940
3941 @staticmethod
3942 def reshapeOp(ser, a, shape):
3943 output_shape = shape.copy()
3944
3945 totalElements = 1
3946 for i in a.shape:
3947 totalElements *= i
3948
3949 # If there are any -1 elements, figure out what that dimension must be
3950 totalOutputElements = 1
3951 for i in output_shape:
3952 if i != -1:
3953 totalOutputElements *= i
3954
3955 # And fill it in
3956 for i in range(len(output_shape)):
3957 if output_shape[i] == -1:
3958 output_shape[i] = totalElements // totalOutputElements
3959
Kevin Cheng550ccc52021-03-03 11:21:43 -08003960 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003961
3962 @staticmethod
3963 def sliceOp(ser, a, begin, size):
3964
3965 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003966 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003967
3968 @staticmethod
3969 def tileOp(ser, a, multiples):
3970
3971 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003972 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003973
3974 for i in range(len(output_shape)):
3975 output_shape[i] = a.shape[i] * multiples[i]
3976
Kevin Cheng550ccc52021-03-03 11:21:43 -08003977 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003978
3979 @staticmethod
3980 def transposeOp(ser, a, perms):
3981 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003982 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003983
3984 for i in range(len(output_shape)):
3985 output_shape[i] = a.shape[perms[i]]
3986
Kevin Cheng550ccc52021-03-03 11:21:43 -08003987 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003988
3989 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08003990 def gatherOp(ser, values, indices):
3991 assert len(values.shape) == 3
3992 assert len(indices.shape) == 2
3993 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07003994
Kevin Cheng77d0f762020-11-24 10:26:32 -08003995 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
3996
Kevin Cheng550ccc52021-03-03 11:21:43 -08003997 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003998
3999 @staticmethod
4000 def scatterOp(ser, values_in, indices, input):
4001 assert len(values_in.shape) == 3
4002 assert len(indices.shape) == 2
4003 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004004 assert values_in.shape[0] == indices.shape[0] # N
4005 assert input.shape[1] == indices.shape[1] # W
4006 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004007
4008 output_shape = values_in.shape
4009
Kevin Cheng550ccc52021-03-03 11:21:43 -08004010 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004011
4012 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004013 def tableOp(ser, input, table_dtype):
4014 # Same shape as the input, but dtype dependent on table dtype
4015 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
4016 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
4017 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004018
4019 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004020 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004021 serializer,
4022 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004023 input,
4024 mode,
4025 stride,
4026 offset,
4027 shift,
4028 stride_fp,
4029 offset_fp,
4030 output_dims,
4031 input_dtype,
4032 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004033 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08004034 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01004035 if error_name == ErrorIf.WrongRank:
4036 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
4037 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004038 if error_name == ErrorIf.BatchMismatch:
4039 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
4040 elif error_name == ErrorIf.ChannelMismatch:
4041 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
4042 else:
4043 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004044
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004045 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004046
4047 @staticmethod
4048 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004049 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004050
4051 @staticmethod
4052 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08004053 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004054 out_dtype = DType.INT32
4055 elif ifm.dtype == DType.INT16:
4056 out_dtype = DType.INT48
4057 elif ifm.dtype == DType.FLOAT:
4058 out_dtype = DType.FLOAT
4059 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004060 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004061
Kevin Cheng550ccc52021-03-03 11:21:43 -08004062 return ser.addOutput(output_shape, out_dtype)