blob: b9cca18d2dfc3a7c0036dc03aae885c329578574 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# Convenience variables to the flatc-generated types that should be enums, but aren't
# (instantiated once here so members can be referenced as e.g. DType.INT8, Op.CONV2D)
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a random zero point for a tensor of the given dtype.

        INT8 draws from [-128, 128), UINT8 from [0, 256); all other dtypes
        use 0.  When error_name requests a *ZeroPointNotZero ERROR_IF test
        the returned value is guaranteed to be non-zero.
        """
        if error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            # Bug fix: this check must come before the dtype checks.
            # Previously the INT8/UINT8 branches were tested first, so an
            # INT8/UINT8 ERROR_IF test could randomly draw a zero point of 0
            # and never trigger the error being tested.
            if dtype == DType.UINT8:
                zero_point = testGen.randInt(1, 256)
            else:
                zero_point = testGen.randInt(-128, 128)
                if zero_point == 0:
                    zero_point = 1
            return zero_point

        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        # All other dtypes must use a zero point of 0
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo (input_zp, output_zp).

        error_name selects which of the two zero points is forced non-zero.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo (input_zp, weights_zp).

        dtype_or_dtypeList is either a single dtype used for all operands or
        a [input, weights, accumulator] list of dtypes.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo (a_zp, b_zp); both zero points are forced
        non-zero for InputZeroPointNotZero tests."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo (input_zp)."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Decompose a floating-point scale into (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32.  Provide a
        floating-point scaling factor and the scale32 parameter (31-bit vs
        15-bit multiplier precision) to compute the fixed-point multiplier
        and the right-shift amount.  The returned shift is always in [2, 62].
        """
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # frexp yields m in [0.5, 1), but rounding can still reach exactly
        # 1 << scaleBits - renormalize into range.
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
175
176
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const
    tensor data operands for the operator.  The actual random data is
    generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """All operands share one random shape of the given rank.

        For RankMismatch ERROR_IF tests, operands other than index 1 are
        regenerated with a deliberately different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                # Mutate the shape used by the *next* operand so ranks differ
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """One shared random NHWC (rank 4) shape for all operands."""
        pl, const = opName["operands"]

        # WrongRank tests deliberately pass a non-4 rank
        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Shapes for gather/scatter: [values_in (N,K,C), indices/input (N,W,C)]."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Same shape for all operands except one randomly-chosen input,
        which gets a random dimension set to 1 to exercise broadcasting."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Shapes for CONV2D: [NHWC input, OHWI filter, OC bias]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Shapes for CONV3D: [NDHWC input, ODHWI filter, OC bias]."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Shapes for TRANSPOSE_CONV2D: [NHWC input, OHWI filter, OC bias]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Shapes for DEPTHWISE_CONV2D: [NHWC input, HWCM filter, M*C bias]."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Shapes for FULLY_CONNECTED: [input (N,IC), filter (OC,IC), bias (OC)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # Bug fix: assign the restricted shape back to input_shape - previously
        # the result was stored in an unused local 'shape' and discarded.
        input_shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Shapes for MATMUL: [a (N,H,C), b (N,C,W)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # Bug fix: assign the restricted shape back to a_shape - previously the
        # result was stored in an unused local 'shape' and discarded.
        a_shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """One shared shape for the operands plus 0-3 extra concat inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first concat shape along axis so multiple const inputs
        can be used without making too many large tensors.

        Returns shapeList unchanged when there are only two shapes or the
        axis is too small to split.
        """
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
490
491
Eric Kunzee5e26762020-10-13 16:11:07 -0700492class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800493 """Argument generators create exhaustive or random lists of attributes for operators that take
494 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
495 tuples where the descriptive_name is appended to the test name and the arglist is expanded
496 as arguments to the operator build function."""
497
    def __init__(self):
        # No per-instance state; all generators in this class are @staticmethods.
        pass
500
501 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100502 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800503 """A trivial argument generator for operators that don't take any
504 non-tensor arguments"""
505 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
507 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100508 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800509 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700511 shape = shapeList[0]
512
Matthew Haddond6ce7252021-09-29 15:35:44 +0100513 if error_name == ErrorIf.AxisSmallerZero:
514 small_axis = testGen.rng.integers(-5, 0)
515 axes.append(("axis{}".format(small_axis), [small_axis]))
516 elif error_name == ErrorIf.AxisLargerRank:
517 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
518 axes.append(("axis{}".format(large_axis), [large_axis]))
519 else:
520 for a in range(0, len(shape)):
521 axes.append(("axis{}".format(a), [a]))
522
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 return axes
524
    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, dilation) argument tuples for conv tests.

        The cross product of all stride/padding/dilation values (plus a few
        oversize values) is sampled sparsely, keeping only combinations where
        the padded input exceeds both the kernel and the dilation.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        # (paddings need 2 values per spatial dimension: before and after)
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values (only for small inputs, to keep
        # the resulting tensors reasonable)
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        # sorted() keeps the sampled subset deterministic between runs
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list
594
595 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100596 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 arg_list = []
598
599 ifm_shape = shapeList[0]
600 filter_shape = shapeList[1]
601
602 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800603 assert len(ifm_shape) == 4
604 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Les Bell7aa69f42021-09-20 10:44:07 +0100606 # Generate comprehensive argument lists
607 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
608 paddings = {x for x in itertools.product(*([p_vals] * 2))}
609 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
610 strides = {x for x in itertools.product(*([s_vals] * 2))}
611 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
612 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
Les Bell7aa69f42021-09-20 10:44:07 +0100614 # add some oversize argument values
615 if max(ifm_shape) < 64:
616 bigPadding = 9
617 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
618 bigStride = 8
619 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
620 bigDilation = 7
621 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Les Bell7aa69f42021-09-20 10:44:07 +0100623 # There are too many parameter combinations, so generate them sparsely
624 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
625 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
626 if sparsity < 13:
627 sparsity = 1
628 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
629 sparsity += 1
630 n = 0
631 for s in sorted(list(strides)):
632 for p in sorted(list(paddings)):
633 for d in sorted(list(dilations)):
634 if n % sparsity == 0:
635 # Determine the output shape
636 oh = (
637 ifm_shape[1]
638 - filter_shape[1]
639 - (filter_shape[1] - 1) * (d[0] - 1)
640 + 2 * p[0]
641 ) // s[0] + 1
642 ow = (
643 ifm_shape[2]
644 - filter_shape[2]
645 - (filter_shape[2] - 1) * (d[1] - 1)
646 + 2 * p[1]
647 ) // s[1] + 1
648 os = [ifm_shape[0], oh, ow, filter_shape[0]]
649 arg_list.append(
650 (
651 "st{}_pad{}_dilat{}_os{}".format(
652 "".join([str(x) for x in s]),
653 "".join([str(x) for x in p]),
654 "".join([str(x) for x in d]),
655 "x".join([str(x) for x in os]),
656 ),
657 [s, p, d, os],
658 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 )
Les Bell7aa69f42021-09-20 10:44:07 +0100660 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
662 return arg_list
663
664 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100665 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700666 arg_list = []
667 rank = len(shapeList[0])
668
Les Bell7ffccce2021-07-28 15:37:02 +0100669 # Exhaustively test combinations of padding on each side of each dimension
670 # - the range of padding values is defined by pad_min and pad_max
671 # - for padding >9, the name format needs to be more distinctive
672 pad_min, pad_max = 0, 1
673 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100674 if error_name == ErrorIf.PadSmallerZero:
675 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100676 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
677 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700678
Kevin Chengfe392ce2021-10-18 21:51:55 +0000679 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
680 pad_const_int = testGen.getRandNumberDType(dtype)
681 pad_const_fp = 0
682 elif dtype == DType.FLOAT:
683 pad_const_int = 0
684 pad_const_fp = testGen.getRandNumberDType(dtype)
685 else:
686 return []
687
Les Bell7ffccce2021-07-28 15:37:02 +0100688 for paddings in shape_pad_values:
689 name = "pad"
690 for r in range(rank):
691 before, after = paddings[r]
692 name = f"{name}{before}{after}"
Kevin Chengfe392ce2021-10-18 21:51:55 +0000693 arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700694
695 return arg_list
696
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, kernel) argument tuples for pooling tests.

        Valid combinations require padding smaller than the kernel and a
        padded input larger than the kernel.  For the pooling ERROR_IF cases
        the (s, p, k) triple is instead transformed by
        TosaErrorIfArgGen.eiPoolingErrorIf (semantics defined there).
        """
        arg_list = []

        shape = shapeList[0]
        # WrongRank tests deliberately use a non-4D shape
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        # eiPoolingErrorIf may return None values when this
                        # (s, p, k) combination cannot express the error
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
761
762 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100763 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700764 arg_list = []
765
766 # Enumerate the output types here
767 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800768 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700769 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800770 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700771 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800772 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700773 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800774 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700775 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800776 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700777 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800778 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700779
780 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800781 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700782
783 return arg_list
784
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate RESCALE argument tuples: [dtype, scale32, double_round, per_channel].

        Walks all output dtype / flag combinations, skipping those that are
        illegal per the TOSA spec — unless error_name requires the illegal
        combination to be kept so the ERROR_IF test can exercise it.
        """
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            # OutputZeroPointNotZero needs an output type that permits a
            # non-zero zero-point, so skip the 8-bit outputs for that case
            if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
                continue
            if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue
            if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
                # For WrongOutputType keep only the genuinely-wrong pairings
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and scale32 == False:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and double_round == False:
                        continue
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
                            # Illegal condition. Must be scale32=False
                            continue
                        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
832
Kevin Chengaee1fac2020-11-11 13:54:06 -0800833 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100834 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800835 arg_list = []
836
837 if dtype is DType.INT32:
838 for p in range(testGen.args.num_rand_permutations):
839
840 shift = testGen.randInt(0, 32)
841
Kevin Cheng550ccc52021-03-03 11:21:43 -0800842 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800843 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100844 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800845
846 return arg_list
847
848 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100849 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800850 arg_list = []
851
Kevin Cheng550ccc52021-03-03 11:21:43 -0800852 arg_list.append(("roundTrue", [True]))
853 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800854
855 return arg_list
856
Eric Kunzee5e26762020-10-13 16:11:07 -0700857 # Helper function for reshape. Gets some factors of a larger number.
858 @staticmethod
859 def getFactors(val, start=1):
860 factors = []
861
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100862 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700863 if (val % i) == 0:
864 factors.append(i)
865
866 return factors
867
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate new shapes for RESHAPE tests.

        Each candidate shape is built from random factors of the input's total
        element count (so element counts always match), may contain a single
        -1 wildcard dimension, and is de-duplicated against shapes already in
        arg_list. Returns ("perm<p>_rank<r>", [newShape]) tuples.
        """
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough small factors to fill this many dimensions
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever is left so the product matches
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
923
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        """Generate axis permutations for TRANSPOSE tests.

        Normally enumerates permutations of range(rank); for the ERROR_IF
        cases it instead builds permutations of deliberately-invalid indices
        (out of bounds, or with a duplicated index). A random subset bounded
        by num_rand_permutations is returned as ("perm<p>", [perm_list]).
        """
        arg_list = []

        ifm_shape = shapeList[0]


        if error_name == ErrorIf.IndexOutsideBounds:
            # Indices entirely above the valid range, and entirely negative
            incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]


        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list
960
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (start, size) pairs for SLICE tests.

        Picks a random in-bounds start/size per axis (fixed to 0/1 for
        singleton axes), discards draws with a zero-sized slice, and — for
        ERROR_IF tests — corrupts the result via eiSliceErrorIf.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    # Singleton axis: the only legal slice is [0:1]
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list
991
992 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100993 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700994 arg_list = []
995
996 ifm_shape = shapeList[0]
997 rank = len(ifm_shape)
998
999 for p in range(testGen.args.num_rand_permutations):
1000
1001 # Pick a few random, but small multiple values
1002 # because otherwise this has a tendency to generate
1003 # enormous tensors
1004 multiples = []
1005 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001006 if ifm_shape[i] > 1000:
1007 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
1008 multiples.append(1)
1009 elif max(ifm_shape) > 1000:
1010 multiples.append(2)
1011 else:
1012 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001013 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001014
1015 return arg_list
1016
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument tuples for RESIZE tests.

        For each legal {mode, output dtype} pair, randomly picks output
        dimensions and derives stride/offset from them: floating-point
        values for FLOAT output, fixed-point (shifted) integers otherwise.
        ERROR_IF variants corrupt the computed arguments through
        TosaErrorIfArgGen.eiResizeErrorIf. Each argument payload is
        [mode, stride, offset, shift, stride_fp, offset_fp, output_dims,
        dtype, outputDType].
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Keep the downscale factor below 16 in each dimension
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Centre-aligned resampling: offsets derived from the
                    # centre points of input and output
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT output uses the *_fp arguments; the integer
                        # fixed-point arguments stay zeroed
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        # Integer output: quantize stride/offset with a
                        # fixed-point shift, starting at 11 fractional bits
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        # Reduce the shift until the quantized values fit
                        # within the spec's (-16, 16) unit range
                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list
1187
Kevin Chengfe392ce2021-10-18 21:51:55 +00001188 @staticmethod
1189 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1190 arg_list = []
1191
1192 if dtype == DType.INT8:
1193 table = np.int32(
1194 testGen.rng.integers(low=-128, high=128, size=[256])
1195 ).tolist()
1196 else: # INT16
1197 table = np.int32(
1198 testGen.rng.integers(low=-32768, high=32768, size=[513])
1199 ).tolist()
1200
1201 arg_list.append(
1202 (
1203 "",
1204 [table],
1205 )
1206 )
1207 return arg_list
1208
Matthew Haddon1c00b712021-10-01 15:51:03 +01001209 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001210 # CondIf generates the condition values here.
1211 # Convert to tensors in the build function, along with the
1212 # then and else blocks
1213 arg_list = []
1214
1215 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001216 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001217
1218 return arg_list
1219
Matthew Haddon1c00b712021-10-01 15:51:03 +01001220 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001221 # While loop: 0 iterations, 1, more than 1
1222 arg_list = []
1223
1224 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001225 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001226
1227 return arg_list
1228
class TosaErrorIfArgGen:
    """Helpers that corrupt otherwise-valid operator arguments so a generated
    test triggers one specific TOSA ERROR_IF condition."""

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Return (shift, stride, stride_fp, offset, offset_fp, outputDType)
        with the field(s) relevant to error_name made invalid.

        FLOAT output uses the *_fp arguments; integer outputs use the
        fixed-point shift/stride/offset arguments.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output dtype that is illegal for this {mode, input dtype}
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType


    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return a (stride, pad, kernel) triple with one element made invalid
        for error_name, or (None, None, None) if this combination cannot
        produce the requested error."""
        if (error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None


    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if output_dtype is illegal for a RESCALE with
        input_dtype (i.e. usable for a WrongOutputType ERROR_IF test)."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        if input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False


    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Return (input_list, output_list) with one list corrupted (an entry
        randomly added or removed) for the WrongInputList/WrongOutputList
        ERROR_IF checks; other error names leave both lists untouched."""
        # Mess up input/output tensors for ERROR_IF checks
        # NOTE(review): error_name is compared against plain strings here while
        # other helpers compare against ErrorIf enum members — confirm callers
        # really pass "WrongInputList"/"WrongOutputList" strings.
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list


    @staticmethod
    def eiRestrictDimension(shape, error_name):
        """Cap one dimension of an oversized WrongRank test shape (in place)
        and return the shape."""
        # Restrict dimension size if rank is large for WrongRank Error_If
        # This will keep the test sizes reasonably small
        if error_name == ErrorIf.WrongRank:
            if len(shape) > 4:
                shape[4] = 1

        return shape


    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) corrupted to trigger the requested SLICE
        ERROR_IF case; other error names return the arguments unchanged."""
        # Fix: decorated as @staticmethod for consistency with every other
        # helper in this class (it never uses an instance).
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i]-1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size
1392
Matthew Haddone86fd342021-09-07 16:12:21 +01001393class TosaErrorValidator:
1394
Matthew Haddon848efb42021-09-09 12:30:53 +01001395 @staticmethod
1396 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1397 # Check ERROR_IF statements
1398
1399 for val_fcn in validator_fcns:
1400 val_result = val_fcn(True, **kwargs)
1401
1402 validator_name = val_result['error_name']
1403 error_result = val_result['error_result']
1404 error_reason = val_result['error_reason']
1405
1406 if error_result:
1407 if error_name == validator_name:
1408 serializer.setExpectedReturnCode(2, error_reason)
1409 else:
1410 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1411 return None # Return None to delete test if wrong ERROR_IF is hit
1412 else:
1413 if error_name == validator_name:
1414 print(f"No ERROR_IF hit for {error_name}")
1415 return None
1416
1417 @staticmethod
1418 def evWrongInputType(check=False, **kwargs):
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001419 all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
Matthew Haddon848efb42021-09-09 12:30:53 +01001420
1421 # Find the unsupported input data types
1422 assert 'op' in kwargs
1423 op = kwargs['op']
1424 input_dtypes = op['types']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001425
1426 allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
1427 wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
Matthew Haddon848efb42021-09-09 12:30:53 +01001428
1429 error_name = ErrorIf.WrongInputType
1430 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1431 error_result = False
1432 error_reason = "Input data type not supported for this operator"
1433
1434 if check:
1435 input_dtype = kwargs['input_dtype']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001436 if op['op'] == Op.FULLY_CONNECTED:
1437 if input_dtype not in allowed_input_dtypes:
1438 error_result = True
1439 elif input_dtype not in input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001440 error_result = True
1441
1442 info_dict = {
1443 "error_name": error_name,
1444 "error_result": error_result,
1445 "error_reason": error_reason,
1446 "param_reqs": param_reqs
1447 }
1448 return info_dict
1449
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Flag an output dtype that is invalid for the operator/input-dtype pair.

        Covers the per-op output-type rules for RESIZE, RESCALE,
        FULLY_CONNECTED/MATMUL and ARGMAX; any other op requires the
        output dtype to equal the input dtype.
        """
        error_name = ErrorIf.WrongOutputType
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output data type not supported for this configuration of operator"

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            op = kwargs['op']

            if op['op'] == Op.RESIZE:
                mode = kwargs['mode']
                # NEAREST keeps the input dtype; BILINEAR widens
                # (INT8->INT32, INT16->INT48); FLOAT stays FLOAT.
                if (
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            elif op['op'] == Op.RESCALE:
                # NOTE(review): the INT16/INT32 test below is an 'if', not an
                # 'elif' of the INT8 branch - harmless here because the dtype
                # tests are mutually exclusive, but the INT48/UINT8 branches
                # chain off it; confirm this nesting is intentional.
                if input_dtype == DType.INT8:
                    if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                if input_dtype in [DType.INT16, DType.INT32]:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.INT48:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.UINT8:
                    if output_dtype != DType.INT8:
                        error_result = True
            elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator widening: INT8->INT32, INT16->INT48, FLOAT->FLOAT.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            elif op['op'] == Op.ARGMAX:
                # ARGMAX always yields INT32 indices.
                if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
                    error_result = True
            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1506
1507 @staticmethod
1508 def evWrongRank(check=False, **kwargs):
1509 all_ranks = (1, 2, 3, 4, 5)
1510
1511 # Make a list of incorrect ranks
1512 assert 'op' in kwargs
1513 op = kwargs['op']
1514 rmin, rmax = op['rank']
1515 rank_range = range(rmin, rmax + 1)
1516 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001517 # Remove small incorrect ranks to avoid index errors
1518 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
Matthew Haddon848efb42021-09-09 12:30:53 +01001519 # Set minimum incorrect rank to 3 to avoid index error
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001520 if op['op'] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001521 incorrect_ranks = [3, 5]
1522
1523 error_name = ErrorIf.WrongRank
1524 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1525 error_result = False
1526 error_reason = "Rank not supported for this operator"
1527
1528 if check:
1529 input_shape = kwargs['input_shape']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001530
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001531 if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
Matthew Haddon848efb42021-09-09 12:30:53 +01001532 error_result = True
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001533 elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
1534 error_result = True
1535 elif op['op'] == Op.MATMUL and len(input_shape) != 3:
1536 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001537 else:
1538 if len(input_shape) not in rank_range:
1539 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001540
1541 info_dict = {
1542 "error_name": error_name,
1543 "error_result": error_result,
1544 "error_reason": error_reason,
1545 "param_reqs": param_reqs
1546 }
1547 return info_dict
1548
1549 @staticmethod
1550 def evWrongInputList(check=False, **kwargs):
1551 error_name = ErrorIf.WrongInputList
1552 param_reqs = {"rank": None, "dtype": None, "shape": None}
1553 error_result = False
1554 error_reason = "Op input list does not match expected input"
1555
1556 if check:
1557 op = kwargs['op']
1558 input_list = kwargs['input_list']
1559 num_operands = kwargs['num_operands']
Kevin Chengfe392ce2021-10-18 21:51:55 +00001560 if len(input_list) != num_operands:
1561 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001562
1563 info_dict = {
1564 "error_name": error_name,
1565 "error_result": error_result,
1566 "error_reason": error_reason,
1567 "param_reqs": param_reqs
1568 }
1569 return info_dict
1570
1571 @staticmethod
1572 def evWrongOutputList(check=False, **kwargs):
1573 error_name = ErrorIf.WrongOutputList
1574 param_reqs = {"rank": None, "dtype": None, "shape": None}
1575 error_result = False
1576 error_reason = "Op output list does not match expected output"
1577
1578 if check:
1579 output_list = kwargs['output_list']
1580 # Note this will be incorrect if an operator returns more than one output
1581 if len(output_list) != 1:
1582 error_result = True
1583
1584 info_dict = {
1585 "error_name": error_name,
1586 "error_result": error_result,
1587 "error_reason": error_reason,
1588 "param_reqs": param_reqs
1589 }
1590 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001591
1592 @staticmethod
1593 def evMaxDimExceeded(check=False, **kwargs):
1594 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001595 param_reqs = {
1596 "rank": [4,4],
1597 "dtype": [DType.INT8],
1598 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1599 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001600 error_result = False
1601 error_reason = "At least one maximum dimension is larger than 16384"
1602
1603 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001604 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001605 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1606 if ((input_shape[1] > 16384) or
1607 (input_shape[2] > 16384) or
1608 (output_shape[0] > 16384) or
1609 (output_shape[1] > 16384)):
1610 error_result = True
1611
1612 info_dict = {
1613 "error_name": error_name,
1614 "error_result": error_result,
1615 "error_reason": error_reason,
1616 "param_reqs": param_reqs
1617 }
1618 return info_dict
1619
1620 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001621 def evBatchMismatch(check=False, **kwargs):
1622 error_name = ErrorIf.BatchMismatch
1623 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1624 error_result = False
1625 error_reason = "Input batch size not equal to output batch size"
1626
1627 assert 'op' in kwargs
1628 op = kwargs['op']
1629 rmin, rmax = op['rank']
1630 rank_range = range(rmin, rmax + 1)
1631
1632 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001633 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001634 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1635
1636 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1637 error_result = True
1638
1639 info_dict = {
1640 "error_name": error_name,
1641 "error_result": error_result,
1642 "error_reason": error_reason,
1643 "param_reqs": param_reqs
1644 }
1645 return info_dict
1646
1647 @staticmethod
1648 def evChannelMismatch(check=False, **kwargs):
1649 error_name = ErrorIf.ChannelMismatch
1650 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1651 error_result = False
1652 error_reason = "Input channel size not equal to output channel size"
1653
1654 assert 'op' in kwargs
1655 op = kwargs['op']
1656 rmin, rmax = op['rank']
1657 rank_range = range(rmin, rmax + 1)
1658
1659 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001660 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001661 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1662 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1663 error_result = True
1664
1665 info_dict = {
1666 "error_name": error_name,
1667 "error_result": error_result,
1668 "error_reason": error_reason,
1669 "param_reqs": param_reqs
1670 }
1671 return info_dict
1672
1673 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001674 def evStrideSmallerEqualZero(check=False, **kwargs):
1675 error_name = ErrorIf.StrideSmallerEqualZero
1676 param_reqs = {"rank": None, "dtype": None, "shape": None}
1677 error_result = False
1678 error_reason = "Stride value smaller than or equal zero"
1679
1680 if check:
1681 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001682 output_dtype = kwargs['output_dtype']
1683 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1684 stride = kwargs['stride'] # Work around wrong input/output type tests
1685 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001686 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001687 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1688 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001689 else:
1690 stride = kwargs['stride']
1691
1692 if min(stride) <= 0:
1693 error_result = True
1694
1695 info_dict = {
1696 "error_name": error_name,
1697 "error_result": error_result,
1698 "error_reason": error_reason,
1699 "param_reqs": param_reqs
1700 }
1701 return info_dict
1702
1703 @staticmethod
1704 def evStrideLargerEqualMax(check=False, **kwargs):
1705 error_name = ErrorIf.StrideLargerEqualMax
1706 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1707 error_result = False
1708 error_reason = "Stride value larger than or equal to maximum value"
1709
1710 if check:
1711 shift = kwargs['shift']
1712 input_dtype = kwargs['input_dtype']
1713 stride = kwargs['stride']
1714 if input_dtype in [DType.INT8, DType.INT16]:
1715 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1716 error_result = True
1717 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1718 error_result = True
1719
1720 info_dict = {
1721 "error_name": error_name,
1722 "error_result": error_result,
1723 "error_reason": error_reason,
1724 "param_reqs": param_reqs
1725 }
1726 return info_dict
1727
1728
1729 @staticmethod
1730 def evStrideLargerDimension(check=False, **kwargs):
1731 error_name = ErrorIf.StrideLargerDimension
1732 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1733 error_result = False
1734 error_reason = "Stride value larger than or equal to H/W dimension"
1735
1736 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001737 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001738 input_dtype = kwargs['input_dtype']
1739 stride = kwargs['stride_fp']
1740
1741 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1742 error_result = True
1743
1744 info_dict = {
1745 "error_name": error_name,
1746 "error_result": error_result,
1747 "error_reason": error_reason,
1748 "param_reqs": param_reqs
1749 }
1750 return info_dict
1751
1752
1753 @staticmethod
1754 def evOffsetSmallerEqualMin(check=False, **kwargs):
1755 error_name = ErrorIf.OffsetSmallerEqualMin
1756 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1757 error_result = False
1758 error_reason = "Offset value smaller than or equal to minimum value"
1759
1760 if check:
1761 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001762 output_dtype = kwargs['output_dtype']
1763 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001764 offset = kwargs['offset_fp']
1765 else:
1766 offset = kwargs['offset']
1767
1768 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1769 error_result = True
1770 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1771 error_result = True
1772
1773 info_dict = {
1774 "error_name": error_name,
1775 "error_result": error_result,
1776 "error_reason": error_reason,
1777 "param_reqs": param_reqs
1778 }
1779 return info_dict
1780
1781 @staticmethod
1782 def evOffsetLargerEqualMax(check=False, **kwargs):
1783 error_name = ErrorIf.OffsetLargerEqualMax
1784 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1785 error_result = False
1786 error_reason = "Offset value larger than or equal to maximum value"
1787
1788 if check:
1789 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001790 output_dtype = kwargs['output_dtype']
1791 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001792 offset = kwargs['offset_fp']
1793 else:
1794 offset = kwargs['offset']
1795
1796 if shift >= 0:
1797 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1798 error_result = True
1799
1800 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1801 error_result = True
1802 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1803 error_result = True
1804
1805 info_dict = {
1806 "error_name": error_name,
1807 "error_result": error_result,
1808 "error_reason": error_reason,
1809 "param_reqs": param_reqs
1810 }
1811 return info_dict
1812
1813 @staticmethod
1814 def evShiftNotZero(check=False, **kwargs):
1815 error_name = ErrorIf.ShiftNotZero
1816 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1817 error_result = False
1818 error_reason = "Shift value must be zero for float input"
1819
1820 if check:
1821 shift = kwargs['shift']
1822 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001823 output_dtype = kwargs['output_dtype']
1824 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001825 error_result = True
1826
1827 info_dict = {
1828 "error_name": error_name,
1829 "error_result": error_result,
1830 "error_reason": error_reason,
1831 "param_reqs": param_reqs
1832 }
1833 return info_dict
1834
1835
1836 @staticmethod
1837 def evShiftSmallerOne(check=False, **kwargs):
1838 error_name = ErrorIf.ShiftSmallerOne
1839 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1840 error_result = False
1841 error_reason = "Shift value smaller than one"
1842
1843 if check:
1844 shift = kwargs['shift']
1845 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001846 output_dtype = kwargs['output_dtype']
1847 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001848 error_result = True
1849
1850 info_dict = {
1851 "error_name": error_name,
1852 "error_result": error_result,
1853 "error_reason": error_reason,
1854 "param_reqs": param_reqs
1855 }
1856 return info_dict
1857
1858 @staticmethod
1859 def evShiftLargerEleven(check=False, **kwargs):
1860 error_name = ErrorIf.ShiftLargerEleven
1861 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1862 error_result = False
1863 error_reason = "Shift value larger than eleven"
1864
1865 if check:
1866 shift = kwargs['shift']
1867 if shift > 11:
1868 error_result = True
1869
1870 info_dict = {
1871 "error_name": error_name,
1872 "error_result": error_result,
1873 "error_reason": error_reason,
1874 "param_reqs": param_reqs
1875 }
1876 return info_dict
1877
1878
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001879 @staticmethod
1880 def evRankMismatch(check=False, **kwargs):
1881 error_name = ErrorIf.RankMismatch
1882 param_reqs = {"rank": None, "dtype": None, "shape": None}
1883 error_result = False
1884 error_reason = "Input Rank does not match output rank"
1885
1886 if check:
1887 input1_shape = kwargs['input1'].shape
1888 input2_shape = kwargs['input2'].shape
1889 output_shape = kwargs['result_tensor'].shape
1890 if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
1891 error_result = True
1892
1893 info_dict = {
1894 "error_name": error_name,
1895 "error_result": error_result,
1896 "error_reason": error_reason,
1897 "param_reqs": param_reqs
1898 }
1899 return info_dict
1900
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """Flag a non-zero input zero point on a dtype that must use zp == 0.

        param_reqs lists the dtypes eligible for this error: INT8/UINT8 are
        removed because they may legitimately carry a zero point.
        """
        op = kwargs['op']
        inputDtypes = op['types'].copy()
        # If inputDtypes is a list then only the first two elements are INT8 inputs
        # NOTE(review): .copy() of op['types'] is always a list, so this
        # isinstance test always passes and the first two entries are always
        # dropped - confirm the intent (likely aimed at per-input dtype
        # lists only).
        if isinstance(inputDtypes, list):
            inputDtypes = inputDtypes[2:]

        # INT8/UINT8 inputs are allowed a zero point, so exclude them from
        # the dtypes this error is generated for.
        if DType.INT8 in inputDtypes:
            inputDtypes.remove(DType.INT8)
        if DType.UINT8 in inputDtypes:
            inputDtypes.remove(DType.UINT8)

        error_name = ErrorIf.InputZeroPointNotZero
        param_reqs = {
            "rank": None,
            "dtype": inputDtypes,
            "shape": None
        }
        error_result = False
        error_reason = "Input DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs['input_dtype']
            # qinfo arrives either as a plain (input_zp, output_zp) tuple or
            # as a serializer object exposing .ints.
            if isinstance(kwargs['qinfo'], tuple):
                qinfo = kwargs['qinfo']
                input_zero_point = qinfo[0]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs['qinfo'].ints
                input_zero_point = qinfo[0][1]

            if op['op'] == Op.MATMUL:
                # MATMUL has two inputs, each with its own zero point.
                input1_dtype = kwargs['input_dtype']
                input2_dtype = kwargs['input2_dtype']
                qinfo = kwargs['qinfo'].ints
                input1_zero_point = qinfo[0][1]
                input2_zero_point = qinfo[1][1]
                if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
                    error_result = True
            else:
                if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1952
1953
1954 @staticmethod
1955 def evWeightZeroPointNotZero(check=False, **kwargs):
1956 op = kwargs['op']
1957
1958 # exclude inputs with INT8 weights
1959 inputDtypes = [t for t in op['types']
1960 if not isinstance(t, list) or t[1] != DType.INT8]
1961
1962 error_name = ErrorIf.WeightZeroPointNotZero
1963 param_reqs = {
1964 "rank": None,
1965 "dtype": inputDtypes,
1966 "shape": None
1967 }
1968 error_result = False
1969 error_reason = "Weight DType not INT8 and zero point not 0"
1970
1971 if check:
1972 weight_dtype = kwargs['weight_dtype']
1973 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
1974 qinfo = kwargs['qinfo'].ints
1975 weight_zero_point = qinfo[1][1]
1976 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001977 error_result = True
1978
1979 info_dict = {
1980 "error_name": error_name,
1981 "error_result": error_result,
1982 "error_reason": error_reason,
1983 "param_reqs": param_reqs
1984 }
1985 return info_dict
1986
1987
1988 @staticmethod
1989 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001990 op = kwargs['op']
1991 inputDtypes = op['types'].copy()
1992 if DType.INT8 in inputDtypes:
1993 inputDtypes.remove(DType.INT8)
1994 if DType.UINT8 in inputDtypes:
1995 inputDtypes.remove(DType.UINT8)
1996
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001997 error_name = ErrorIf.OutputZeroPointNotZero
1998 param_reqs = {
1999 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002000 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002001 "shape": None
2002 }
2003 error_result = False
2004 error_reason = "Output DType not INT8 and zero point not 0"
2005
2006 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002007 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01002008 output_dtype = kwargs['output_dtype']
2009 if isinstance(kwargs['qinfo'], tuple):
2010 qinfo = kwargs['qinfo']
2011 output_zero_point = qinfo[1]
2012 else:
2013 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
2014 qinfo = kwargs['qinfo'].ints
2015 output_zero_point = qinfo[1][1]
2016 if op['op'] == Op.AVG_POOL2D:
2017 if input_dtype != DType.INT8 and output_zero_point != 0:
2018 error_result = True
2019 elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002020 error_result = True
2021
2022 info_dict = {
2023 "error_name": error_name,
2024 "error_result": error_result,
2025 "error_reason": error_reason,
2026 "param_reqs": param_reqs
2027 }
2028 return info_dict
2029
Matthew Haddond6ce7252021-09-29 15:35:44 +01002030 @staticmethod
2031 def evAxisSmallerZero(check=False, **kwargs):
2032 error_name = ErrorIf.AxisSmallerZero
2033 param_reqs = {"rank": None, "dtype": None, "shape": None}
2034 error_result = False
2035 error_reason = "Axis smaller than zero"
2036
2037 if check:
2038 axis = kwargs['axis']
2039 if axis < 0:
2040 error_result = True
2041
2042 info_dict = {
2043 "error_name": error_name,
2044 "error_result": error_result,
2045 "error_reason": error_reason,
2046 "param_reqs": param_reqs
2047 }
2048 return info_dict
2049
2050
2051 @staticmethod
2052 def evAxisLargerRank(check=False, **kwargs):
2053 error_name = ErrorIf.AxisLargerRank
2054 param_reqs = {"rank": None, "dtype": None, "shape": None}
2055 error_result = False
2056 error_reason = "Axis larger than rank"
2057
2058 if check:
2059 axis = kwargs['axis']
2060 shape = kwargs['input_shape']
2061 if axis > len(shape):
2062 error_result = True
2063
2064 info_dict = {
2065 "error_name": error_name,
2066 "error_result": error_result,
2067 "error_reason": error_reason,
2068 "param_reqs": param_reqs
2069 }
2070 return info_dict
2071
2072
2073 @staticmethod
2074 def evShapeOfAxisNotOne(check=False, **kwargs):
2075 error_name = ErrorIf.ShapeOfAxisNotOne
2076 param_reqs = {"rank": None, "dtype": None, "shape": None}
2077 error_result = False
2078 error_reason = "shape[axis] is not equal to 1"
2079
2080 if check:
2081 axis = kwargs['axis']
2082 shape = kwargs['output_shape']
2083 if (0 <= axis < len(shape)) and shape[axis] != 1:
2084 error_result = True
2085
2086 info_dict = {
2087 "error_name": error_name,
2088 "error_result": error_result,
2089 "error_reason": error_reason,
2090 "param_reqs": param_reqs
2091 }
2092 return info_dict
2093
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002094
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002095 @staticmethod
2096 def evPadSmallerZero(check=False, **kwargs):
2097 error_name = ErrorIf.PadSmallerZero
2098 param_reqs = {"rank": None, "dtype": None, "shape": None}
2099 error_result = False
2100 error_reason = "At least one pad is smaller than zero"
2101
2102 if check:
Matthew Haddone807aae2021-10-11 18:12:58 +01002103 op = kwargs['op']
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002104 pad = kwargs['pad']
Matthew Haddone807aae2021-10-11 18:12:58 +01002105 if op['op'] == Op.PAD:
2106 for padding in pad:
2107 if min(padding) < 0:
2108 error_result = True
2109 else:
2110 if min(pad) < 0:
2111 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002112
2113 info_dict = {
2114 "error_name": error_name,
2115 "error_result": error_result,
2116 "error_reason": error_reason,
2117 "param_reqs": param_reqs
2118 }
2119 return info_dict
2120
2121
2122 @staticmethod
2123 def evPadLargerEqualKernel(check=False, **kwargs):
2124 error_name = ErrorIf.PadLargerEqualKernel
2125 param_reqs = {"rank": None, "dtype": None, "shape": None}
2126 error_result = False
2127 error_reason = "At least one pad is larger than kernel dimension"
2128
2129 if check:
2130 pad = kwargs['pad']
2131 kernel = kwargs['kernel']
2132 if min(pad) > 0 and min(kernel) > 1:
2133 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
2134 error_result = True
2135
2136 info_dict = {
2137 "error_name": error_name,
2138 "error_result": error_result,
2139 "error_reason": error_reason,
2140 "param_reqs": param_reqs
2141 }
2142 return info_dict
2143
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """Flag a pooling output shape that differs from the computed H/W.

        The expected output is (I + pad_before + pad_after + stride - kernel)
        // stride per spatial axis; the comparison is only made when pad,
        stride and kernel are themselves valid, so this error does not
        overlap with the dedicated pad/stride/kernel ERROR_IF checks.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between output shape provided and expected output shape"

        if check:
            pad = kwargs['pad']
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs['kernel']
            kernel_y, kernel_x = kernel[0], kernel[1]

            # Shapes are NHWC; index 1/2 are the spatial dims.
            input_shape = kwargs['input_shape']
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs['output_shape']
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs['stride']
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            # (guarded so a zero stride cannot divide by zero; in that case
            # params_valid below is False and short-circuits before
            # y_correct/x_correct are read)
            if stride_x != 0 and stride_y != 0:
                y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
                x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x

            # ensure parameters are valid
            params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
                            and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
2186
    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        """Flag an ARGMAX output shape that is not the input shape minus the axis.

        Only dimension values are judged here; the rank relationship itself
        is covered by evArgmaxOutputRankMismatch.
        """
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between output shape provided and expected output shape"

        if check:
            output_shape = kwargs['output_shape']
            input_shape = kwargs['input_shape']
            axis = kwargs['axis']

            dimension_match = True
            # Once the reduced axis has been passed, input index i maps to
            # output index i - 1; axis_shift tracks that offset.
            axis_shift = 0

            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
2221
2222 @staticmethod
2223 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2224 error_name = ErrorIf.ArgmaxOutputRankMismatch
2225 param_reqs = {"rank": None, "dtype": None, "shape": None}
2226 error_result = False
2227 error_reason = "Mismatch between output shape provided and expected output shape"
2228
2229 if check:
2230 output_shape = kwargs['output_shape']
2231 input_shape = kwargs['input_shape']
2232 axis = kwargs['axis']
2233 valid_params = axis >= 0 and axis < len(input_shape)
2234
2235 if valid_params and (len(input_shape) - 1) != len(output_shape):
2236 error_result = True
2237
2238 info_dict = {
2239 "error_name": error_name,
2240 "error_result": error_result,
2241 "error_reason": error_reason,
2242 "param_reqs": param_reqs
2243 }
2244 return info_dict
2245
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002246
2247 @staticmethod
2248 def evKernelSmallerOne(check=False, **kwargs):
2249 error_name = ErrorIf.KernelSmallerOne
2250 param_reqs = {"rank": None, "dtype": None, "shape": None}
2251 error_result = False
2252 error_reason = "At least one kernel dimension is smaller than zero"
2253
2254 if check:
2255 kernel = kwargs['kernel']
2256 if min(kernel) < 1:
2257 error_result = True
2258
2259 info_dict = {
2260 "error_name": error_name,
2261 "error_result": error_result,
2262 "error_reason": error_reason,
2263 "param_reqs": param_reqs
2264 }
2265 return info_dict
2266
2267 @staticmethod
2268 def evStrideSmallerOne(check=False, **kwargs):
2269 error_name = ErrorIf.StrideSmallerOne
2270 param_reqs = {"rank": None, "dtype": None, "shape": None}
2271 error_result = False
2272 error_reason = "At least one stride dimension is smaller than zero"
2273
2274 if check:
2275 stride = kwargs['stride']
2276 if min(stride) < 1:
2277 error_result = True
2278
2279 info_dict = {
2280 "error_name": error_name,
2281 "error_result": error_result,
2282 "error_reason": error_reason,
2283 "param_reqs": param_reqs
2284 }
2285 return info_dict
2286
Matthew Haddonc2025212021-10-08 21:21:05 +01002287 @staticmethod
2288 def evScaleTrue(check=False, **kwargs):
2289 error_name = ErrorIf.ScaleTrue
2290 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2291 error_result = False
2292 error_reason = "Scale set to true but input type is INT48"
2293
2294 if check:
2295 input_dtype = kwargs['input_dtype']
2296 scale32 = kwargs['scale32']
2297 if scale32 and input_dtype == DType.INT48:
2298 error_result = True
2299
2300 info_dict = {
2301 "error_name": error_name,
2302 "error_result": error_result,
2303 "error_reason": error_reason,
2304 "param_reqs": param_reqs
2305 }
2306 return info_dict
2307
2308 @staticmethod
2309 def evScaleNotTrue(check=False, **kwargs):
2310 error_name = ErrorIf.ScaleNotTrue
2311 param_reqs = {"rank": None, "dtype": None, "shape": None}
2312 error_result = False
2313 error_reason = "Scale set to false but double round set to true"
2314
2315 if check:
2316 scale32 = kwargs['scale32']
2317 double_round = kwargs['double_round']
2318 if not scale32 and double_round:
2319 error_result = True
2320
2321 info_dict = {
2322 "error_name": error_name,
2323 "error_result": error_result,
2324 "error_reason": error_reason,
2325 "param_reqs": param_reqs
2326 }
2327 return info_dict
2328
Matthew Haddone807aae2021-10-11 18:12:58 +01002329 @staticmethod
2330 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2331 error_name = ErrorIf.TensorSizeInputOutputMismatch
2332 param_reqs = {"rank": None, "dtype": None, "shape": None}
2333 error_result = False
2334 error_reason = "Input tensor size does not match output tensor size"
2335
2336 if check:
2337 input_shape = kwargs['input_shape']
2338 output_shape = kwargs['output_shape']
2339 input_size = np.prod(input_shape)
2340 output_size = np.prod(output_shape)
2341 if input_size != output_size:
2342 error_result = True
2343
2344 info_dict = {
2345 "error_name": error_name,
2346 "error_result": error_result,
2347 "error_reason": error_reason,
2348 "param_reqs": param_reqs
2349 }
2350 return info_dict
2351
2352 @staticmethod
2353 def evStartSmallerZero(check=False, **kwargs):
2354 error_name = ErrorIf.StartSmallerZero
2355 param_reqs = {"rank": None, "dtype": None, "shape": None}
2356 error_result = False
2357 error_reason = "Starting point smaller than zero"
2358
2359 if check:
2360 input_shape = kwargs['input_shape']
2361 start = kwargs['start']
2362 rank = len(input_shape)
2363 if len(start) == rank:
2364 for index in range(rank):
2365 if start[index] < 0:
2366 error_result = True
2367
2368 info_dict = {
2369 "error_name": error_name,
2370 "error_result": error_result,
2371 "error_reason": error_reason,
2372 "param_reqs": param_reqs
2373 }
2374 return info_dict
2375
2376
2377 @staticmethod
2378 def evSizeSmallerEqualZero(check=False, **kwargs):
2379 error_name = ErrorIf.SizeSmallerEqualZero
2380 param_reqs = {"rank": None, "dtype": None, "shape": None}
2381 error_result = False
2382 error_reason = "Size smaller than or equal to zero"
2383
2384 if check:
2385 input_shape = kwargs['input_shape']
2386 size = kwargs['size']
2387 rank = len(input_shape)
2388 if len(size) == rank:
2389 for index in range(rank):
2390 if size[index] <= 0:
2391 error_result = True
2392
2393 info_dict = {
2394 "error_name": error_name,
2395 "error_result": error_result,
2396 "error_reason": error_reason,
2397 "param_reqs": param_reqs
2398 }
2399 return info_dict
2400
2401
2402 @staticmethod
2403 def evStartSizeOutsideBounds(check=False, **kwargs):
2404 error_name = ErrorIf.StartSizeOutsideBounds
2405 param_reqs = {"rank": None, "dtype": None, "shape": None}
2406 error_result = False
2407 error_reason = "starting point plus size larger than input dimension"
2408
2409 if check:
2410 input_shape = kwargs['input_shape']
2411 start = kwargs['start']
2412 size = kwargs['size']
2413 rank = len(input_shape)
2414 if len(start) == rank and len(size) == rank:
2415 for index in range(rank):
2416 if start[index] + size[index] > input_shape[index]:
2417 error_result = True
2418
2419 info_dict = {
2420 "error_name": error_name,
2421 "error_result": error_result,
2422 "error_reason": error_reason,
2423 "param_reqs": param_reqs
2424 }
2425 return info_dict
2426
2427
2428 @staticmethod
2429 def evSizeOutputShapeMismatch(check=False, **kwargs):
2430 error_name = ErrorIf.SizeOutputShapeMismatch
2431 param_reqs = {"rank": None, "dtype": None, "shape": None}
2432 error_result = False
2433 error_reason = "Size does not match output dimension"
2434
2435 if check:
2436 input_shape = kwargs['input_shape']
2437 output_shape = kwargs['output_shape']
2438 size = kwargs['size']
2439 rank = len(input_shape)
2440 if len(size) == rank:
2441 for index in range(rank):
2442 if size[index] != output_shape[index]:
2443 error_result = True
2444
2445 info_dict = {
2446 "error_name": error_name,
2447 "error_result": error_result,
2448 "error_reason": error_reason,
2449 "param_reqs": param_reqs
2450 }
2451 return info_dict
2452
2453 @staticmethod
2454 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2455 error_name = ErrorIf.InputSizeStartLengthMismatch
2456 param_reqs = {"rank": None, "dtype": None, "shape": None}
2457 error_result = False
2458 error_reason = "rank of input not equal to length of start or size"
2459
2460 if check:
2461 input_shape = kwargs['input_shape']
2462 start = kwargs['start']
2463 size = kwargs['size']
2464 rank = len(input_shape)
2465 if rank != len(start) or rank != len(size):
2466 error_result = True
2467
2468 info_dict = {
2469 "error_name": error_name,
2470 "error_result": error_result,
2471 "error_reason": error_reason,
2472 "param_reqs": param_reqs
2473 }
2474 return info_dict
2475
2476 @staticmethod
2477 def evIndexOutsideBounds(check=False, **kwargs):
2478 error_name = ErrorIf.IndexOutsideBounds
2479 param_reqs = {"rank": None, "dtype": None, "shape": None}
2480 error_result = False
2481 error_reason = "Index outside of allowed bounds"
2482
2483 if check:
2484 input_shape = kwargs['input_shape']
2485 perms = kwargs['perms']
2486 rank = len(input_shape)
2487
2488 for index in perms:
2489 if index < 0 or index > rank:
2490 error_result = True
2491
2492 info_dict = {
2493 "error_name": error_name,
2494 "error_result": error_result,
2495 "error_reason": error_reason,
2496 "param_reqs": param_reqs
2497 }
2498 return info_dict
2499
2500 @staticmethod
2501 def evIndexUsedTwice(check=False, **kwargs):
2502 error_name = ErrorIf.IndexUsedTwice
2503 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2504 error_result = False
2505 error_reason = "Index used multiple times"
2506
2507 if check:
2508 input_shape = kwargs['input_shape']
2509 perms = kwargs['perms']
2510 rank = len(input_shape)
2511
2512 unique_indices = []
2513 for index in perms:
2514 if index in unique_indices:
2515 error_result = True
2516 else:
2517 unique_indices.append(index)
2518
2519 info_dict = {
2520 "error_name": error_name,
2521 "error_result": error_result,
2522 "error_reason": error_reason,
2523 "param_reqs": param_reqs
2524 }
2525 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002526
2527
class TosaInvalidValidator:
    """Predicates that detect invalid operator parameter combinations.

    Each iv* staticmethod receives the generated test parameters as keyword
    arguments and returns True when the combination is illegal, so the
    generator can treat the test as expected-to-fail.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode/dtype combination is unsupported.

        Supported combinations:
          BILINEAR: INT8->INT32, INT16->INT48, FLOAT->FLOAT
          NEAREST:  INT8->INT8,  INT16->INT16, FLOAT->FLOAT
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # BUGFIX: the original OR-ed the *negated* pairs
            # (`not A or not B or not C or D`), which is always True because
            # at most one pair can match.  Invalid means no pair matches.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Output dtype must match input dtype and the input must be a
            # supported dtype.  BUGFIX: the whitelist previously contained
            # INT32 instead of INT16, inconsistent with the INT16->INT48
            # bilinear pair above.
            return (
                (input_dtype != output_dtype)
                or (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
            )
        else:
            # Any unrecognized resize mode is invalid.
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if any resize stride is zero or negative.

        Float tests carry their stride in args[4] (stride_fp); integer tests
        use the integer stride in args[1].
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if the computed output height/width is <= 0 for a
        conv2d / depthwise_conv2d / *pool2d parameter combination."""
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        ifm_shape = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter_shape = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            # Pooling ops carry the kernel where convs carry dilations.
            kernel = args[2]

        if opName.startswith('conv2d'):
            # Output extent from input, filter, dilation, padding and stride.
            h = (
                ifm_shape[1]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[2]
                - (filter_shape[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            # Depthwise filters keep spatial dims in positions 0/1.
            h = (
                ifm_shape[1]
                - filter_shape[0]
                - (filter_shape[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if the requested output H or W (args[3]) is <= 0."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
2644
2645
Kevin Cheng550ccc52021-03-03 11:21:43 -08002646
Eric Kunzee5e26762020-10-13 16:11:07 -07002647class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6
2650
    def __init__(self, args):
        """Initialize the generator from parsed command-line arguments.

        Seeds the RNG from args.random_seed and builds the operator lists.
        No serializer exists until createSerializer() is called.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): createDynamicOpLists() runs before
        # initOpListDefaults() — order appears deliberate; confirm before
        # reordering.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
2662
2663 def createSerializer(self, opName, testPath):
2664 self.testPath = os.path.join(opName, testPath)
2665
2666 fullPath = os.path.join(self.basePath, self.testPath)
2667 os.makedirs(fullPath, exist_ok=True)
2668 self.ser = ts.TosaSerializer(fullPath)
2669
2670 def getSerializer(self):
2671 return self.ser
2672
    def serialize(self, testName):
        """Write <testName>.tosa and a desc.json descriptor referencing it
        into the directory chosen by createSerializer()."""
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002681
Matthew Haddon74567092021-07-16 15:38:20 +01002682 def resetRNG(self, seed=None):
2683 if seed == None:
2684 seed = self.random_seed + 1
2685 self.rng = np.random.default_rng(seed)
2686
Eric Kunzee5e26762020-10-13 16:11:07 -07002687 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07002688 if dtype == DType.BOOL:
2689 np_dt = np.bool
2690 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07002691 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002692 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002693 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002694 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002695 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
2696 elif dtype == DType.UINT8:
2697 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002698 elif dtype == DType.INT16:
2699 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
2700 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002701 return np.int32(
2702 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
2703 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002704 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002705 return np.int64(
2706 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
2707 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002708 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002709 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002710 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002711 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002712
Kevin Cheng989cb052021-04-28 16:29:44 -07002713 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002714 placeholders = []
2715
Kevin Cheng989cb052021-04-28 16:29:44 -07002716 assert len(shape_list) == len(dtype_list)
2717
2718 for idx, shape in enumerate(shape_list):
2719 arr = self.getRandTensor(shape, dtype_list[idx])
2720 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002721
2722 return placeholders
2723
Kevin Cheng989cb052021-04-28 16:29:44 -07002724 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002725 consts = []
2726
Kevin Cheng989cb052021-04-28 16:29:44 -07002727 assert len(shape_list) == len(dtype_list)
2728
2729 for idx, shape in enumerate(shape_list):
2730 arr = self.getRandTensor(shape, dtype_list[idx])
2731 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002732
2733 return consts
2734
2735 def makeShape(self, rank):
2736 if self.targetted_shape:
2737 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002738 return np.int32(
2739 self.rng.integers(
2740 low=self.args.tensor_shape_range[0],
2741 high=self.args.tensor_shape_range[1],
2742 size=rank,
2743 )
2744 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002745
2746 def setTargetShape(self, shape):
2747 self.targetted_shape = shape
2748
2749 def randInt(self, low=0, high=256):
2750 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
2751
2752 def getRandNumberDType(self, dtype):
2753 if dtype == DType.FLOAT:
2754 return self.rng.random()
2755 elif dtype == DType.BOOL:
2756 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07002757 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002758 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002759 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002760 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002761 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07002762 elif dtype == DType.INT16:
2763 low, high = (-32768, 32768)
2764 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002765 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07002766 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002767 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07002768 # Special size
2769 return np.int64(self.rng.integers(low, high, size=1))[0]
2770 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002771 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002772
2773 return np.int32(self.rng.integers(low, high, size=1))[0]
2774
2775 def shapeStr(self, shape):
2776
2777 sStr = []
2778 # Convert to strings
2779 for i in shape:
2780 sStr.append(str(i))
2781
Kevin Cheng550ccc52021-03-03 11:21:43 -08002782 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002783
2784 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07002785 if isinstance(t, list):
2786 assert len(t) >= 2
2787 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002788 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002789 if t == DType.BOOL:
2790 return "b"
2791 elif t == DType.INT4:
2792 return "i4"
2793 elif t == DType.INT8:
2794 return "i8"
2795 elif t == DType.UINT8:
2796 return "u8"
2797 elif t == DType.INT16:
2798 return "i16"
2799 elif t == DType.INT32:
2800 return "i32"
2801 elif t == DType.INT48:
2802 return "i48"
2803 elif t == DType.FLOAT:
2804 return "float"
2805 else:
2806 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002807
2808 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002809 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002810 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002811 return 4
2812 elif t == DType.INT8:
2813 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002814 elif t == DType.UINT8:
2815 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002816 elif t == DType.INT16:
2817 return 16
2818 elif t == DType.INT32:
2819 return 32
2820 elif t == DType.INT48:
2821 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01002822 elif t == DType.FLOAT:
2823 return 32
2824 elif t == DType.BOOL:
2825 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002826 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002827 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002828
2829 # Argument generators
2830 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2831 # Where the string descriptor is used to generate the test name and
2832 # The build_fcn_arg_list is expanded and passed to the operator test
2833 # build function
2834
    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        """Build a unary elementwise op, optionally as an ERROR_IF test case.

        Args:
            op: op dict, or a raw int op code (see comment below).
            a: input tensor.
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
            qinfo: quantization info; may be regenerated for WrongOutputType.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # build_placeholder returns an int, ABS/other ops does not
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
            return result_tens
        elif op['op'] == Op.IDENTITY:
            # IDENTITY is emitted directly, skipping the ERROR_IF machinery.
            self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens
2877
    def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
        """Build an elementwise binary op with broadcasting, optionally as an
        ERROR_IF test case.

        Args:
            op: op dict describing the operator.
            a, b: input tensors (shapes may broadcast).
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)


        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = a,
            input2 = b,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list)
        return result_tens
2906
2907 def build_binary_nonbroadcast(self, op, a, b):
2908 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002909 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002910 return result_tens
2911
Kevin Chengaee1fac2020-11-11 13:54:06 -08002912 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002913 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002914
2915 attr = ts.TosaSerializerAttribute()
2916 attr.ArithmeticRightShiftAttribute(round)
2917
Matthew Haddon848efb42021-09-09 12:30:53 +01002918 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002919 return result_tens
2920
2921 def build_mul(self, op, a, b, shift):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002922 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Eric Kunzee5e26762020-10-13 16:11:07 -07002923
2924 # Special for multiply:
2925 # Force the result to INT32 for INT types
2926 if a.dtype != DType.FLOAT:
2927 result_tens.setDtype(DType.INT32)
2928
Kevin Chengaee1fac2020-11-11 13:54:06 -08002929 attr = ts.TosaSerializerAttribute()
2930 attr.MulAttribute(shift)
2931
Matthew Haddon848efb42021-09-09 12:30:53 +01002932 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002933 return result_tens
2934
Kevin Chengfe392ce2021-10-18 21:51:55 +00002935 def build_table(self, op, a, table):
2936 result_tens = OutputShaper.tableOp(self.ser, a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002937
Kevin Chengfe392ce2021-10-18 21:51:55 +00002938 attr = ts.TosaSerializerAttribute()
2939 attr.TableAttribute(table)
2940
2941 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002942
2943 return result_tens
2944
2945 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07002946 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002947 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002948 return result_tens
2949
2950 def build_comparison(self, op, a, b):
2951 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002952 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002953 return result_tens
2954
    def build_argmax(self, op, a, axis, validator_fcns, error_name):
        """Build ARGMAX over `axis`, optionally as an ERROR_IF test case.

        Args:
            op: op dict describing the operator.
            a: input tensor.
            axis: reduction axis carried in the axis attribute.
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape = a.shape,
            input_dtype = a.dtype,
            output_shape = result_tens.shape,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
2986
    def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
        """Build a 2D pooling op, optionally as an ERROR_IF test case.

        Args:
            op: op dict describing the operator.
            input: input tensor (pooled over its spatial dims).
            stride, pad, kernel: pooling parameters for the attribute.
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
            qinfo: quantization info; may be regenerated for WrongInputType.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType:
            if input.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=input.shape,
            input_dtype=input.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            kernel=kernel,
            stride=stride,
            pad=pad,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad)

        self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
        return result_tens
3029
3030 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003031 assert len(padding) == 4
3032 result_tens = OutputShaper.conv2dOp(
3033 self.ser, ifm, filter, strides, padding, dilations
3034 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003035
3036 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003037 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003038
Kevin Cheng550ccc52021-03-03 11:21:43 -08003039 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003040 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003041 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003042 return result_tens
3043
Kevin Cheng1533b852021-09-01 12:51:58 -07003044 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
3045 assert len(padding) == 6
3046 result_tens = OutputShaper.conv3dOp(
3047 self.ser, ifm, filter, strides, padding, dilations
3048 )
3049
3050 attr = ts.TosaSerializerAttribute()
3051 attr.ConvAttribute(padding, strides, dilations)
3052
3053 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003054 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07003055 )
3056 return result_tens
3057
Kevin Cheng550ccc52021-03-03 11:21:43 -08003058 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07003059 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003060 ):
3061 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07003062 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
3063
3064 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003065 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003066
Kevin Cheng550ccc52021-03-03 11:21:43 -08003067 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003068 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003069 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003070 return result_tens
3071
Kevin Cheng550ccc52021-03-03 11:21:43 -08003072 def build_depthwise_conv2d(
3073 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
3074 ):
3075 result_tens = OutputShaper.depthwiseConv2dOp(
3076 self.ser, ifm, filter, strides, padding, dilations
3077 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003078
3079 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003080 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003081
Kevin Cheng550ccc52021-03-03 11:21:43 -08003082 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003083 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003084 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003085 return result_tens
3086
    def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
        """Build FULLY_CONNECTED, optionally as an ERROR_IF test case.

        Args:
            op: op dict describing the operator.
            ifm, filter, bias: input, weight and bias tensors.
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
            qinfo: quantization info to attach.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=ifm.shape,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(
            op['op'], input_list, output_list, None, qinfo
        )
        return result_tens
3118
    def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
        """Build MATMUL, optionally as an ERROR_IF test case.

        Args:
            op: op dict describing the operator.
            a, b: input tensors to multiply.
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
            qinfo: quantization info to attach.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            input_dtype=a.dtype,
            input2_shape=b.shape,
            input2_dtype=b.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens
3149
    def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
        """Build a reduction op over `axis`, optionally as an ERROR_IF case.

        Args:
            op: op dict describing the operator.
            a: input tensor.
            axis: reduction axis carried in the axis attribute.
            validator_fcns: ERROR_IF validator functions to run.
            error_name: ErrorIf case being generated, or None for a valid test.
        Returns:
            The result tensor registered with the serializer.
        """
        result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis = axis,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
3181
3182 def build_clamp(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003183 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003184
3185 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01003186 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07003187
3188 if a.dtype == DType.FLOAT:
3189 attr.ClampAttribute(0, 0, min(v), max(v))
3190 else:
3191 attr.ClampAttribute(min(v), max(v), 0, 0)
3192
Matthew Haddon848efb42021-09-09 12:30:53 +01003193 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003194 return result_tens
3195
3196 def build_leaky_relu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003197 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003198 attr = ts.TosaSerializerAttribute()
3199
3200 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
3201
Matthew Haddon848efb42021-09-09 12:30:53 +01003202 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003203 return result_tens
3204
3205 # Needs an additional type/input
3206 def build_prelu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003207 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003208
Matthew Haddon848efb42021-09-09 12:30:53 +01003209 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003210 return result_tens
3211
Eric Kunzee5e26762020-10-13 16:11:07 -07003212 def build_sigmoid(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003213 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01003214 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003215 return result_tens
3216
3217 def build_tanh(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003218 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01003219 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003220 return result_tens
3221
Matthew Haddon818ab902021-07-27 09:12:49 +01003222 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07003223 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01003224
3225 # To store variable length list of input tensors we need to store axis along with it
3226 axis = a[-1]
3227 a = a[:-1]
3228
3229 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07003230
3231 attr = ts.TosaSerializerAttribute()
3232 attr.AxisAttribute(axis)
3233
Matthew Haddon818ab902021-07-27 09:12:49 +01003234 input_tensor_names = []
3235 for tensor in a:
3236 input_tensor_names.append(tensor.name)
3237
Matthew Haddon848efb42021-09-09 12:30:53 +01003238 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
3239 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003240
Kevin Chengfe392ce2021-10-18 21:51:55 +00003241 def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
Matthew Haddone807aae2021-10-11 18:12:58 +01003242 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003243
Kevin Chengfe392ce2021-10-18 21:51:55 +00003244 attr = ts.TosaSerializerAttribute()
3245 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07003246
Matthew Haddone807aae2021-10-11 18:12:58 +01003247 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00003248 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01003249 output_list = [result_tens.name]
3250 pCount, cCount = op["operands"]
3251 num_operands = pCount + cCount
3252 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3253
3254 TosaErrorValidator.evValidateErrorIfs(
3255 self.ser,
3256 validator_fcns,
3257 error_name,
3258 op=op,
3259 input_shape = a.shape,
3260 output_shape = result_tens.shape,
3261 input_dtype = a.dtype,
3262 output_dtype = result_tens.dtype,
3263 pad=padding,
3264 qinfo=qinfo,
3265 result_tensor = result_tens,
3266 input_list=input_list,
3267 output_list=output_list,
3268 num_operands=num_operands,
3269 )
3270
Kevin Cheng550ccc52021-03-03 11:21:43 -08003271 self.ser.addOperator(
Kevin Chengfe392ce2021-10-18 21:51:55 +00003272 op['op'], input_list, output_list, attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003273 )
Matthew Haddone86fd342021-09-07 16:12:21 +01003274 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003275
Matthew Haddone807aae2021-10-11 18:12:58 +01003276 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
3277 result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)
3278
3279 # Invalidate Input/Output list for error if checks.
3280 input_list = [a.name]
3281 output_list = [result_tens.name]
3282 pCount, cCount = op["operands"]
3283 num_operands = pCount + cCount
3284 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3285
3286 TosaErrorValidator.evValidateErrorIfs(
3287 self.ser,
3288 validator_fcns,
3289 error_name,
3290 op=op,
3291 input_shape = a.shape,
3292 output_shape = result_tens.shape,
3293 input_dtype = a.dtype,
3294 output_dtype = result_tens.dtype,
3295 result_tensor = result_tens,
3296 input_list=input_list,
3297 output_list=output_list,
3298 num_operands=num_operands,
3299 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003300
3301 attr = ts.TosaSerializerAttribute()
3302 attr.ReshapeAttribute(newShape)
3303
Matthew Haddone807aae2021-10-11 18:12:58 +01003304 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003305 return result_tens
3306
3307 def build_reverse(self, op, a, axis):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003308 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003309
3310 attr = ts.TosaSerializerAttribute()
3311 attr.AxisAttribute(axis)
3312
Matthew Haddon848efb42021-09-09 12:30:53 +01003313 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003314 return result_tens
3315
Matthew Haddone807aae2021-10-11 18:12:58 +01003316 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
3317 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003318
Kevin Chengfe392ce2021-10-18 21:51:55 +00003319 attr = ts.TosaSerializerAttribute()
3320 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07003321
Matthew Haddone807aae2021-10-11 18:12:58 +01003322 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00003323 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01003324 output_list = [result_tens.name]
3325 pCount, cCount = op["operands"]
3326 num_operands = pCount + cCount
3327 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3328
3329 TosaErrorValidator.evValidateErrorIfs(
3330 self.ser,
3331 validator_fcns,
3332 error_name,
3333 op=op,
3334 input_shape = a.shape,
3335 output_shape = result_tens.shape,
3336 perms=perms,
3337 input_dtype = a.dtype,
3338 output_dtype = result_tens.dtype,
3339 result_tensor = result_tens,
3340 input_list=input_list,
3341 output_list=output_list,
3342 num_operands=num_operands,
3343 )
3344
3345
Kevin Chengfe392ce2021-10-18 21:51:55 +00003346 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003347 return result_tens
3348
Matthew Haddone807aae2021-10-11 18:12:58 +01003349 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
3350 result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)
3351
3352 # Invalidate Input/Output list for error if checks.
3353 input_list = [a.name]
3354 output_list = [result_tens.name]
3355 pCount, cCount = op["operands"]
3356 num_operands = pCount + cCount
3357 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3358
3359 TosaErrorValidator.evValidateErrorIfs(
3360 self.ser,
3361 validator_fcns,
3362 error_name,
3363 op=op,
3364 input_shape = a.shape,
3365 output_shape = result_tens.shape,
3366 input_dtype = a.dtype,
3367 output_dtype = result_tens.dtype,
3368 start=start,
3369 size=size,
3370 result_tensor = result_tens,
3371 input_list=input_list,
3372 output_list=output_list,
3373 num_operands=num_operands,
3374 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003375
3376 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01003377 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07003378
Matthew Haddone807aae2021-10-11 18:12:58 +01003379 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003380 return result_tens
3381
3382 def build_tile(self, op, a, multiples):
3383 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
3384
3385 attr = ts.TosaSerializerAttribute()
3386 attr.TileAttribute(multiples)
3387
Matthew Haddon848efb42021-09-09 12:30:53 +01003388 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003389 return result_tens
3390
Kevin Cheng77d0f762020-11-24 10:26:32 -08003391 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07003392
3393 # Create a new indicies tensor
3394 # here with data that doesn't exceed the dimensions of the values tensor
3395
Kevin Cheng550ccc52021-03-03 11:21:43 -08003396 K = values.shape[1] # K
3397 W = self.randInt(
3398 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
3399 ) # W
3400 indicies_arr = np.int32(
3401 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
3402 ) # (N, W)
3403 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003404
Kevin Cheng77d0f762020-11-24 10:26:32 -08003405 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07003406
Matthew Haddon848efb42021-09-09 12:30:53 +01003407 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003408
3409 return result_tens
3410
Kevin Cheng77d0f762020-11-24 10:26:32 -08003411 def build_scatter(self, op, values_in, input):
3412
3413 # Create a new indicies tensor
3414 # here with data that doesn't exceed the dimensions of the values_in tensor
3415
Kevin Cheng550ccc52021-03-03 11:21:43 -08003416 K = values_in.shape[1] # K
3417 W = input.shape[1] # W
3418 indicies_arr = np.int32(
3419 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
3420 ) # (N, W)
3421 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003422
3423 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
3424
Kevin Cheng550ccc52021-03-03 11:21:43 -08003425 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003426 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003427 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08003428
3429 return result_tens
3430
Matthew Haddon848efb42021-09-09 12:30:53 +01003431
Kevin Cheng550ccc52021-03-03 11:21:43 -08003432 def build_resize(
3433 self,
3434 op,
3435 input,
3436 mode,
3437 stride,
3438 offset,
3439 shift,
3440 stride_fp,
3441 offset_fp,
3442 output_dims,
3443 input_dtype,
3444 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003445 validator_fcns,
3446 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003447 ):
3448 result_tens = OutputShaper.resizeOp(
3449 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003450 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003451 input,
3452 mode,
3453 stride,
3454 offset,
3455 shift,
3456 stride_fp,
3457 offset_fp,
3458 output_dims,
3459 input_dtype,
3460 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003461 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08003462 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003463
Matthew Haddon848efb42021-09-09 12:30:53 +01003464 # Invalidate Input/Output list for error if checks.
3465 input_list = [input.name]
3466 output_list = [result_tens.name]
3467 pCount, cCount = op["operands"]
3468 num_operands = pCount + cCount
3469 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01003470
Matthew Haddon848efb42021-09-09 12:30:53 +01003471 TosaErrorValidator.evValidateErrorIfs(
3472 self.ser,
3473 validator_fcns,
3474 error_name,
3475 op=op,
3476 mode=mode,
3477 shift=shift,
3478 input_dtype=input_dtype,
3479 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003480 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01003481 output_shape=output_dims,
3482 offset=offset,
3483 offset_fp=offset_fp,
3484 stride=stride,
3485 stride_fp=stride_fp,
3486 input_list=input_list,
3487 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003488 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01003489 num_operands=num_operands,
3490 )
Matthew Haddone86fd342021-09-07 16:12:21 +01003491
Eric Kunzee5e26762020-10-13 16:11:07 -07003492 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08003493
Kevin Cheng550ccc52021-03-03 11:21:43 -08003494 attr.ResizeAttribute(
3495 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
3496 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003497
Matthew Haddon848efb42021-09-09 12:30:53 +01003498 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003499 return result_tens
3500
3501 def build_identityn(self, op, val, val2):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003502 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
3503 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003504 self.ser.addOperator(
3505 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
3506 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003507 return result_tens
3508
Kevin Cheng17e92022021-10-01 14:33:33 -07003509 def build_const(self, op, val):
3510 self.ser.addOutputTensor(val)
3511 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07003512
3513 # Type Conversion
3514 def build_cast(self, op, val, out_dtype):
3515 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01003516 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003517 return result_tens
3518
Matthew Haddonc2025212021-10-08 21:21:05 +01003519 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
Eric Kunzee5e26762020-10-13 16:11:07 -07003520 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
3521
3522 if per_channel:
3523 nc = val.shape[-1]
3524 else:
3525 nc = 1
3526
3527 in_type_width = self.typeWidth(val.dtype)
3528 out_type_width = self.typeWidth(out_dtype)
3529
Kevin Cheng3a478572021-01-22 17:21:02 -08003530 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003531 input_zp = self.randInt(-128, 128)
3532 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07003533 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003534 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07003535 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01003536 elif error_name == ErrorIf.InputZeroPointNotZero:
3537 input_zp = self.randInt(-128, 128)
3538 if input_zp == 0:
3539 input_zp = input_zp + self.rng.integers(1, 10)
3540 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003541 else:
3542 input_zp = 0
3543
Kevin Cheng3a478572021-01-22 17:21:02 -08003544 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003545 output_zp = self.randInt(-128, 128)
3546 out_type_width = out_type_width + 1
3547 elif out_dtype == DType.UINT8:
3548 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07003549 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01003550 elif error_name == ErrorIf.OutputZeroPointNotZero:
3551 output_zp = self.randInt(-128, 128)
3552 if output_zp == 0:
3553 output_zp = output_zp + self.rng.integers(1, 10)
3554 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003555 else:
3556 output_zp = 0
3557
3558 # Calculate scale based on:
3559 # scale = a *(2^output_width)/(2^input_width))
3560
3561 a = np.float32(self.rng.random(size=[nc]))
3562 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
3563
3564 if scale32:
3565 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01003566 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07003567 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
3568 else:
3569 # Cap the scaling at 2^15 - 1 for scale16
3570 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3571
Kevin Cheng550ccc52021-03-03 11:21:43 -08003572 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003573
3574 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3575 shift_arr = np.int32(np.zeros(shape=[nc]))
3576
3577 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003578 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
3579 scale_arr[i], scale32
3580 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003581
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07003583
Matthew Haddonc2025212021-10-08 21:21:05 +01003584 # Invalidate Input/Output list for error if checks.
3585 input_list = [val.name]
3586 output_list = [result_tens.name]
3587 pCount, cCount = op["operands"]
3588 num_operands = pCount + cCount
3589 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3590
3591 qinfo = (input_zp, output_zp)
3592 TosaErrorValidator.evValidateErrorIfs(
3593 self.ser,
3594 validator_fcns,
3595 error_name,
3596 op=op,
3597 input_dtype=val.dtype,
3598 output_dtype=out_dtype,
3599 input_shape=val.shape,
3600 qinfo=qinfo,
3601 scale32 = scale32,
3602 double_round = double_round,
3603 input_list=input_list,
3604 output_list=output_list,
3605 result_tensor=result_tens,
3606 num_operands=num_operands,
3607 )
3608
Eric Kunzee5e26762020-10-13 16:11:07 -07003609 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003610 attr.RescaleAttribute(
3611 input_zp,
3612 output_zp,
3613 multiplier_arr,
3614 shift_arr,
3615 scale32,
3616 double_round,
3617 per_channel,
3618 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003619
Matthew Haddonc2025212021-10-08 21:21:05 +01003620 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003621 return result_tens
3622
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Serialize a COND_IF whose then/else blocks each hold one random
        INT32 const tensor, selected at runtime by the boolean `cond`."""
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
3658
    def build_cond_if_binary(self, op, a, b, cond):
        """Serialize a COND_IF whose then/else blocks apply two different
        binary ops to `a` and `b`, selected at runtime by the boolean `cond`."""
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # Choose the pair of block ops by dtype: ADD/SUB for FLOAT/INT32,
        # logical shifts for INT8/INT16; any other dtype is unsupported.
        if a.dtype in (DType.FLOAT, DType.INT32):
            then_op, else_op = Op.ADD, Op.SUB
        elif a.dtype in (DType.INT8, DType.INT16):
            then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
        else:
            assert False, f"No tests for DType: {a.dtype}"

        # Emit each block: same two inputs, one binary op, one output.
        for block, op in ((then_block, then_op), (else_block, else_op)):
            self.ser.startBasicBlock(block)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
            self.ser.addOperator(op, [a.name, b.name], [tens.name])

        return result_tens
3694
    def build_while_loop(self, op, a, iter_val):
        """Serialize a WHILE_LOOP that runs `iter_val` iterations, adding `a`
        into an accumulator each pass; returns the accumulator output tensor."""
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor, initialised to zeros
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # COND block (input: iter, output: cond_tens )
        # Loop continues while iter > 0.
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        # Body computes: acc += a; iter -= 1
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
3748
Matthew Haddon1c00b712021-10-01 15:51:03 +01003749 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
3750 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
3751 default_test_rank_range = range(1, 5)
3752 if not shapeFilter:
3753 shapeFilter = [None]
3754
3755 # Calculate the filters based on what is requested and what the operator allows
3756 rmin, rmax = op["rank"]
3757 if rankFilter is not None:
3758 cleanRankFilter = []
3759 # Ensure rankFilter values are allowed by operator
3760 for rank in rankFilter:
3761 if rank >= rmin and rank <= rmax:
3762 cleanRankFilter.append(rank)
3763 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01003764 # Ensure default behaviour is bounded by default range or by operator,
3765 # whichever is the smaller range of ranks.
3766 opRankRange = range(rmin, rmax + 1)
3767 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01003768 else:
3769 cleanRankFilter = range(rmin, rmax + 1)
3770
3771 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003772
Matthew Haddon1c00b712021-10-01 15:51:03 +01003773 if dtypeFilter is not None:
3774 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01003775 # Create list of operator dtypes filtered by requested dtypes
3776 for dtype in dtypes:
3777 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01003778 cleanDtypeFilter.append(dtype)
3779 else:
3780 cleanDtypeFilter = dtypes
3781
3782 if testType == 'positive':
3783 filterDict = {
3784 'shapeFilter': shapeFilter,
3785 'rankFilter': cleanRankFilter,
3786 'dtypeFilter': cleanDtypeFilter
3787 }
3788 return filterDict
3789 elif testType == 'negative':
Matthew Haddone807aae2021-10-11 18:12:58 +01003790 if validator is not None:
3791 validator_info = validator(check=False, op=op)
3792 else:
3793 return None
3794
Matthew Haddon1c00b712021-10-01 15:51:03 +01003795 error_arguments = validator_info['param_reqs']
3796
3797 #Set parameters as required
3798 if error_arguments['rank'] != None:
3799 rankFilter = error_arguments['rank']
3800 else:
3801 rankFilter = cleanRankFilter
3802
3803 if error_arguments['dtype'] != None:
3804 dtypeFilter = error_arguments['dtype']
3805 else:
3806 dtypeFilter = cleanDtypeFilter
3807
3808 if error_arguments['shape'] != None:
3809 shapeFilter = error_arguments['shape']
3810 else:
3811 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
3812
3813 filterDict = {
3814 'shapeFilter': shapeFilter,
3815 'rankFilter': rankFilter,
3816 'dtypeFilter': dtypeFilter
3817 }
3818 return filterDict
3819
3820
Kevin Cheng550ccc52021-03-03 11:21:43 -08003821 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01003822 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08003823 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07003824
3825 try:
3826 op = self.TOSA_OP_LIST[opName]
3827 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003828 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003829
3830 # Initialize a new random number generator
3831 self.rng = np.random.default_rng(self.random_seed)
3832
Kevin Cheng550ccc52021-03-03 11:21:43 -08003833 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003834
Eric Kunzee5e26762020-10-13 16:11:07 -07003835 # Test list consists of a tuple of:
3836 # (opName, testNameStr, dtype, shapeList, argumentsList)
3837 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003838 if testType == 'negative' and "error_if_validators" in op:
3839 error_if_validators = op["error_if_validators"]
3840 else:
3841 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07003842
Matthew Haddon1c00b712021-10-01 15:51:03 +01003843 for validator in error_if_validators:
3844 if validator is not None:
3845 error_name = validator(check=False, op=op)['error_name']
Matthew Haddon1c00b712021-10-01 15:51:03 +01003846 else:
3847 error_name = None
3848
3849 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
Matthew Haddone807aae2021-10-11 18:12:58 +01003850 if filterDict == None:
3851 return []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003852 cleanRankFilter = filterDict['rankFilter']
3853 cleanDtypeFilter = filterDict['dtypeFilter']
3854 cleanShapeFilter = filterDict['shapeFilter']
3855 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
3856
3857 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07003858 if opName.startswith("conv3d"):
3859 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01003860 for t in cleanDtypeFilter:
3861 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01003862 # Filter out by rank
3863 if shape is not None and len(shape) != r:
3864 continue
Matthew Haddon74567092021-07-16 15:38:20 +01003865 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003866 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003867
Matthew Haddon74567092021-07-16 15:38:20 +01003868 shapeStr = self.shapeStr(shapeList[0])
3869 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07003870
Matthew Haddon74567092021-07-16 15:38:20 +01003871 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
3872 argList = []
3873 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003874 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003875 else:
Matthew Haddon74567092021-07-16 15:38:20 +01003876 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07003877
Matthew Haddon74567092021-07-16 15:38:20 +01003878 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003879 if testType == 'positive':
3880 if argStr:
3881 testStr = "{}_{}_{}_{}".format(
3882 opName, shapeStr, typeStr, argStr
3883 )
3884 else:
3885 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
3886 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01003887 if argStr:
3888 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
3889 opName, error_name, shapeStr, typeStr, argStr
3890 )
3891 else:
3892 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003893
3894 testList.append((opName, testStr, t, error_name, shapeList, args))
3895
3896 if testType == 'positive':
3897 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
3898 if "invalid_test_validators" in op:
3899 invalid_test_validators = op["invalid_test_validators"]
3900 clean_testList = []
3901 for test in testList:
3902 for validator_fcn in invalid_test_validators:
3903 remove_test = False
3904 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
3905 remove_test = True
3906 if not remove_test:
3907 clean_testList.append(test)
3908 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07003909
3910 return testList
3911
Matthew Haddone86fd342021-09-07 16:12:21 +01003912
3913 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07003914 try:
3915 op = self.TOSA_OP_LIST[opName]
3916 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003917 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003918
3919 # Create a serializer
3920 self.createSerializer(opName, testStr)
3921
Kevin Cheng550ccc52021-03-03 11:21:43 -08003922 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003923 if "error_if_validators" in op:
3924 error_if_validators = op["error_if_validators"]
3925 else:
3926 error_if_validators = None
3927
Kevin Cheng550ccc52021-03-03 11:21:43 -08003928 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003929 num_operands = pCount + cCount
3930
3931 if isinstance(dtype_or_dtypeList, list):
3932 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07003933 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003934 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003935 else:
3936 dtypeList = [dtype_or_dtypeList] * (num_operands)
3937
Kevin Cheng93a16282021-08-31 16:14:03 -07003938 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003939 assert (
3940 len(shapeList) == num_operands
3941 ), "shapeList length {} must match number of operands {}".format(
3942 len(shapeList), num_operands
3943 )
3944 assert (
3945 len(dtypeList) == num_operands
3946 ), "dtypeList length {} must match number of operands {}".format(
3947 len(dtypeList), num_operands
3948 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003949
3950 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003951 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003952 except KeyError:
3953 qgen = None
3954
3955 # Build the random tensor operands and the test
3956 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08003957
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003958 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003959
3960 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003961 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003962 else:
3963 qinfo = None
3964
3965 try:
3966 if error_if_validators is None:
3967 if qinfo is not None:
3968 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
3969 else:
3970 resultName = build_fcn(self, op, *tens, *testArgs)
3971 else:
3972 if qinfo is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003973 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003974 else:
3975 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
3976 except TypeError as e:
3977 print(
3978 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
3979 build_fcn, tens, testArgs
3980 )
3981 )
3982 raise e
3983
3984 if resultName is None:
3985 print("Invalid ERROR_IF tests created")
3986
3987 # Save the serialized test
3988 self.serialize("test")
3989
3990
    def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create the operand tensors (placeholders then consts) for one test.

        Most ops get pCount random placeholders followed by cCount random
        consts, but several ops need specially constrained data; each such
        case is a branch below.

        Args:
            op: TOSA_OP_LIST entry for the operator under test
            dtypeList: datatype per operand (modified in place for SELECT,
                whose condition operand is forced to BOOL)
            shapeList: shape per operand
            testArgs: build-function arguments (e.g. MUL reads its shift
                from testArgs[0])
            error_name: ERROR_IF case name, or None for a positive test

        Returns:
            list of serializer tensor handles
        """
        pCount, cCount = op["operands"]

        tens = []
        # ADD/SUB on INT32 (positive tests only): clip operand b so the
        # int32 result cannot saturate.
        if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = (op["op"] == Op.ADD)
            a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the exact result in int64 so overflow is observable
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31)-1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                # (broadcast dims must stay size 1; amin keeps the clip safe
                # for every broadcast position)
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            tens.extend(placeholders)
        elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops
            # (data is drawn from the INT16 range but the tensors keep their
            # declared INT32 dtype)
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                arr = self.getRandTensor(shapeList[idx], DType.INT16)
                if pRemain > 0:
                    placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
                    pRemain -= 1
                else:
                    placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            # Force value of operand[1] to be within [0, num_bits]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Shift-amount operand: bounded by the operand bit width
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.INTDIV and error_name == None:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Re-draw both operands until neither case is present.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop regenerates a_arr/b_arr from
                # shapeList[0]/[1] on every iteration and only the final
                # iteration's arrays are serialized below; the earlier
                # passes just consume RNG state. Also the counter `i` is
                # never read. Left as-is to preserve the RNG stream and
                # therefore the generated test data.
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    i = 0
                    # Halve both operands until the shifted product fits in
                    # int32 everywhere
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2 ** 31)).all() and (
                            result_arr <= ((2 ** 31) - 1)
                        ).all():
                            break

                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            # First `count` inputs are placeholders, the rest consts,
            # honouring the --num-const-inputs-concat command line option
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            # Default: unconstrained random placeholders then consts
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004205
4206 def createDynamicOpLists(self):
4207
4208 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07004209 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004210
Kevin Cheng1533b852021-09-01 12:51:58 -07004211 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004212 testName = "conv2d_{}x{}".format(k[0], k[1])
4213 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
4214 self.TOSA_OP_LIST[testName]["filter"] = k
4215 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004216
Kevin Cheng550ccc52021-03-03 11:21:43 -08004217 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
4218 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
4219 "depthwise_conv2d_TEMPLATE"
4220 ].copy()
4221 self.TOSA_OP_LIST[testName]["filter"] = k
4222 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004223
Kevin Cheng550ccc52021-03-03 11:21:43 -08004224 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
4225 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
4226 "transpose_conv2d_TEMPLATE"
4227 ].copy()
4228 self.TOSA_OP_LIST[testName]["filter"] = k
4229 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004230
Kevin Cheng1533b852021-09-01 12:51:58 -07004231 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
4232 for k in KERNELS_3D:
4233 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
4234 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
4235 self.TOSA_OP_LIST[testName]["filter"] = k
4236 self.TOSA_OP_LIST[testName]["template"] = False
4237
Eric Kunzee5e26762020-10-13 16:11:07 -07004238 # Delete any templates after having created any dynamic ops
4239 # This is a two-pass operation because it's bad practice to delete
4240 # keys from dictionaries while iterating
4241 keyList = []
4242 for k in self.TOSA_OP_LIST:
4243 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004244 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07004245 keyList.append(k)
4246 continue
4247 except KeyError:
4248 pass
4249
4250 for k in keyList:
4251 del self.TOSA_OP_LIST[k]
4252
4253 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004254 """Fill in default fields for ops if they aren't already specified.
4255 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07004256 for op in self.TOSA_OP_LIST:
4257
4258 # Required fields
4259 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004261 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004262 raise Exception(
4263 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
4264 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004265
4266 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004267 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004268 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004269 raise Exception(
4270 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
4271 op
4272 )
4273 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004274
4275 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004276 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004277 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004278 raise Exception(
4279 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
4280 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004281
4282 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004283 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004284 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004285 raise Exception(
4286 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
4287 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004288
4289 # Put in default rank range, if missing
4290 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004291 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004292 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004293 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07004294
    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    # if not specified, defaults to DEFAULT_RANK_RANGE below
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]  # Floating-point only

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]  # Excludes INT32

    # Convolution dtype combinations: each entry is either a single dtype or
    # a dtype triple (presumably input/weight/accumulator - confirm against
    # the conv build functions)
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Rank range applied by initOpListDefaults() to any op without an
    # explicit 'rank' entry
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07004322
4323 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08004324 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004325 "argmax": {
4326 "op": Op.ARGMAX,
4327 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004328 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004329 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4330 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004331 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
4332 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4333 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004334 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004335 "avg_pool2d": {
4336 "op": Op.AVG_POOL2D,
4337 "operands": (1, 0),
4338 "rank": (4, 4),
4339 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4340 "qgen": TosaQuantGen.qgUnary,
4341 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004342 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4343 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4344 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4345 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4346 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004347 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004348 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004349 "conv2d_TEMPLATE": {
4350 "op": Op.CONV2D,
4351 "operands": (1, 2),
4352 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01004353 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004354 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004355 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004356 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004357 "template": True,
4358 },
Kevin Cheng1533b852021-09-01 12:51:58 -07004359 # Templated operator. Filled in by createDynamicOpLists
4360 "conv3d_TEMPLATE": {
4361 "op": Op.CONV3D,
4362 "operands": (1, 2),
4363 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01004364 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07004365 "qgen": TosaQuantGen.qgConv,
4366 "types": TYPE_CONV,
4367 "template": True,
4368 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004369 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004370 "depthwise_conv2d_TEMPLATE": {
4371 "op": Op.DEPTHWISE_CONV2D,
4372 "operands": (1, 2),
4373 "filter": [1, 1],
4374 "rank": (4, 4),
4375 "build_fcn": (
4376 build_depthwise_conv2d,
4377 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01004378 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004379 ),
4380 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004381 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004382 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004383 "template": True,
4384 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004385 "fully_connected": {
4386 "op": Op.FULLY_CONNECTED,
4387 "operands": (1, 2),
4388 "rank": (2, 2),
4389 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
4390 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004391 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004392 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
4393 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004394 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004395 "matmul": {
4396 "op": Op.MATMUL,
4397 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07004398 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08004399 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
4400 "qgen": TosaQuantGen.qgMatmul,
4401 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004402 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4403 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004404 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004405 "max_pool2d": {
4406 "op": Op.MAX_POOL2D,
4407 "operands": (1, 0),
4408 "rank": (4, 4),
4409 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4410 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004411 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4412 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4413 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4414 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004415 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004416 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004417 "transpose_conv2d_TEMPLATE": {
4418 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07004419 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004420 "rank": (4, 4),
4421 "build_fcn": (
4422 build_transpose_conv2d,
4423 TosaTensorGen.tgTransposeConv2D,
4424 TosaArgGen.agTransposeConv2D,
4425 ),
4426 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004427 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004428 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004429 "template": True,
4430 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004431 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08004432 "clamp": {
4433 "op": Op.CLAMP,
4434 "operands": (1, 0),
4435 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
4436 "types": TYPE_NARROW_INT_FP,
4437 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004438 "sigmoid": {
4439 "op": Op.SIGMOID,
4440 "operands": (1, 0),
4441 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
4442 "types": TYPE_FP,
4443 },
4444 "tanh": {
4445 "op": Op.TANH,
4446 "operands": (1, 0),
4447 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
4448 "types": TYPE_FP,
4449 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004450 # Elementwise Binary Operators
4451 "add": {
4452 "op": Op.ADD,
4453 "operands": (2, 0),
4454 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4455 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004456 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4457 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004458 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004459 "arithmetic_right_shift": {
4460 "op": Op.ARITHMETIC_RIGHT_SHIFT,
4461 "operands": (2, 0),
4462 "build_fcn": (
4463 build_arithmetic_right_shift,
4464 TosaTensorGen.tgBroadcastFuzz,
4465 TosaArgGen.agArithmeticRightShift,
4466 ),
4467 "types": TYPE_INT,
4468 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004469 "bitwise_and": {
4470 "op": Op.BITWISE_AND,
4471 "operands": (2, 0),
4472 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4473 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004474 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4475 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004477 "bitwise_or": {
4478 "op": Op.BITWISE_OR,
4479 "operands": (2, 0),
4480 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4481 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004482 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4483 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004484 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004485 "bitwise_xor": {
4486 "op": Op.BITWISE_XOR,
4487 "operands": (2, 0),
4488 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4489 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004490 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4491 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004492 },
Matthew Haddon459443c2021-08-23 16:43:13 +01004493 "intdiv": {
4494 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004495 "operands": (2, 0),
4496 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4497 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004498 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4499 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004500 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004501 "logical_and": {
4502 "op": Op.LOGICAL_AND,
4503 "operands": (2, 0),
4504 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4505 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004506 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4507 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004508 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004509 "logical_left_shift": {
4510 "op": Op.LOGICAL_LEFT_SHIFT,
4511 "operands": (2, 0),
4512 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4513 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004514 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4515 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004516 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004517 "logical_right_shift": {
4518 "op": Op.LOGICAL_RIGHT_SHIFT,
4519 "operands": (2, 0),
4520 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4521 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004522 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4523 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004524 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004525 "logical_or": {
4526 "op": Op.LOGICAL_OR,
4527 "operands": (2, 0),
4528 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4529 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004530 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4531 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004532 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004533 "logical_xor": {
4534 "op": Op.LOGICAL_XOR,
4535 "operands": (2, 0),
4536 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4537 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004538 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4539 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004540 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004541 "maximum": {
4542 "op": Op.MAXIMUM,
4543 "operands": (2, 0),
4544 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4545 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004546 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4547 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004549 "minimum": {
4550 "op": Op.MINIMUM,
4551 "operands": (2, 0),
4552 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4553 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004554 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4555 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004556 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004557 "mul": {
4558 "op": Op.MUL,
4559 "operands": (2, 0),
4560 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
4561 "types": TYPE_INT_FP,
4562 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004563 "pow": {
4564 "op": Op.POW,
4565 "operands": (2, 0),
4566 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
4567 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004568 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4569 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004571 "sub": {
4572 "op": Op.SUB,
4573 "operands": (2, 0),
4574 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4575 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004576 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4577 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004578 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004579 "table": {
4580 "op": Op.TABLE,
4581 # Use the automatic generation functions to create the input array
4582 # but create the table tensor in the build function, as it may be
4583 # a different type from the input
4584 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00004585 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004586 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08004587 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004588 # Elementwise Unary operators
4589 "abs": {
4590 "op": Op.ABS,
4591 "operands": (1, 0),
4592 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4593 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004594 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4595 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004596 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004597 "bitwise_not": {
4598 "op": Op.BITWISE_NOT,
4599 "operands": (1, 0),
4600 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4601 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004602 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4603 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004605 "ceil": {
4606 "op": Op.CEIL,
4607 "operands": (1, 0),
4608 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4609 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004610 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4611 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004612 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004613 "clz": {
4614 "op": Op.CLZ,
4615 "operands": (1, 0),
4616 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4617 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004618 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4619 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004620 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004621 "exp": {
4622 "op": Op.EXP,
4623 "operands": (1, 0),
4624 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4625 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004626 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4627 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004629 "floor": {
4630 "op": Op.FLOOR,
4631 "operands": (1, 0),
4632 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4633 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004634 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4635 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004636 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004637 "log": {
4638 "op": Op.LOG,
4639 "operands": (1, 0),
4640 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4641 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004642 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4643 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004644 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004645 "logical_not": {
4646 "op": Op.LOGICAL_NOT,
4647 "operands": (1, 0),
4648 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4649 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004650 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4651 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004653 "negate": {
4654 "op": Op.NEGATE,
4655 "operands": (1, 0),
4656 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4657 "qgen": TosaQuantGen.qgUnary,
4658 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004659 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4660 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4661 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004662 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004663 "reciprocal": {
4664 "op": Op.RECIPROCAL,
4665 "operands": (1, 0),
4666 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4667 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004668 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4669 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004670 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004671 "rsqrt": {
4672 "op": Op.RSQRT,
4673 "operands": (1, 0),
4674 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4675 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004676 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4677 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004678 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004679 # Elementwise Ternary operators
4680 "select": {
4681 "op": Op.SELECT,
4682 "operands": (3, 0),
4683 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
4684 "types": TYPE_FIB,
4685 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004686 # Comparison operators
4687 "equal": {
4688 "op": Op.EQUAL,
4689 "operands": (2, 0),
4690 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4691 "types": TYPE_FI32,
4692 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004693 "greater_equal": {
4694 "op": Op.GREATER_EQUAL,
4695 "operands": (2, 0),
4696 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4697 "types": TYPE_FI32,
4698 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004699 "greater": {
4700 "op": Op.GREATER,
4701 "operands": (2, 0),
4702 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4703 "types": TYPE_FI32,
4704 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004705 # Reduction operators
4706 "reduce_all": {
4707 "op": Op.REDUCE_ALL,
4708 "operands": (1, 0),
4709 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4710 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004711 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4712 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4713 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004714 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004715 "reduce_any": {
4716 "op": Op.REDUCE_ANY,
4717 "operands": (1, 0),
4718 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4719 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004720 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4721 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4722 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004724 "reduce_max": {
4725 "op": Op.REDUCE_MAX,
4726 "operands": (1, 0),
4727 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4728 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004729 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4730 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4731 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004732 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004733 "reduce_min": {
4734 "op": Op.REDUCE_MAX,
4735 "operands": (1, 0),
4736 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4737 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004738 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4739 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4740 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004741 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004742 "reduce_product": {
4743 "op": Op.REDUCE_PRODUCT,
4744 "operands": (1, 0),
4745 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4746 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004747 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4748 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4749 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004750 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004751 "reduce_sum": {
4752 "op": Op.REDUCE_SUM,
4753 "operands": (1, 0),
4754 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4755 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004756 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4757 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4758 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004759 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004760 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004761 "concat": {
4762 "op": Op.CONCAT,
4763 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01004764 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004765 "types": TYPE_FIB,
4766 },
4767 "pad": {
4768 "op": Op.PAD,
4769 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004770 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004771 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4772 "qgen": TosaQuantGen.qgPad,
4773 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004774 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
4775 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004776 },
4777 "reshape": {
4778 "op": Op.RESHAPE,
4779 "operands": (1, 0),
4780 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4781 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004782 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4783 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004784 },
4785 "reverse": {
4786 "op": Op.REVERSE,
4787 "operands": (1, 0),
4788 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4789 "types": TYPE_FIB,
4790 },
4791 "slice": {
4792 "op": Op.SLICE,
4793 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004794 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004795 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4796 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004797 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
4798 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
4799 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004800 },
4801 "tile": {
4802 "op": Op.TILE,
4803 "operands": (1, 0),
4804 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4805 "types": TYPE_FIB,
4806 },
4807 "transpose": {
4808 "op": Op.TRANSPOSE,
4809 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004810 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004811 "build_fcn": (
4812 build_transpose,
4813 TosaTensorGen.tgBasic,
4814 TosaArgGen.agTranspose,
4815 ),
4816 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004817 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
4818 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004819 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004820 # Data nodes
4821 "const": {
4822 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004823 "operands": (0, 1),
4824 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004825 "types": TYPE_FIB,
4826 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004827 "identity": {
4828 "op": Op.IDENTITY,
4829 "operands": (1, 0),
4830 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4831 "types": TYPE_FIB,
4832 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004833 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004834 "gather": {
4835 "op": Op.GATHER,
4836 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4837 "operands": (1, 0),
4838 "rank": (3, 3),
4839 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4840 "types": TYPE_INT_FP,
4841 },
4842 "scatter": {
4843 "op": Op.SCATTER,
4844 # Only specify 'values_in' tensor here.
4845 #'indices' and 'input' are generated in op building stage
4846 "operands": (2, 0),
4847 "rank": (3, 3),
4848 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4849 "types": TYPE_INT_FP,
4850 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004851 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004852 "resize": {
4853 "op": Op.RESIZE,
4854 "operands": (1, 0),
4855 "rank": (4, 4),
4856 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4857 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004858 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4859 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4860 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004861 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004862 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4863 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004864 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004865 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004866 "cast": {
4867 "op": Op.CAST,
4868 "operands": (1, 0),
4869 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4870 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4871 },
4872 "rescale": {
4873 "op": Op.RESCALE,
4874 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01004875 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004876 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004877 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01004878 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
4879 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4880 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004881 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004882 # Custom
4883 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004884 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004885 # Two varients of cond_if, one that generates one of two constant tensors (no
4886 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4887 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004888 "cond_if_const": {
4889 "op": Op.COND_IF,
4890 "operands": (0, 2),
4891 "build_fcn": (
4892 build_cond_if_const,
4893 TosaTensorGen.tgBasic,
4894 TosaArgGen.agCondIf,
4895 ),
4896 "types": [DType.BOOL],
4897 },
4898 "cond_if_binary": {
4899 "op": Op.COND_IF,
4900 "operands": (2, 0),
4901 "build_fcn": (
4902 build_cond_if_binary,
4903 TosaTensorGen.tgBasic,
4904 TosaArgGen.agCondIf,
4905 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004906 "types": TYPE_INT_FP,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004907 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004908 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004909 "while_loop": {
4910 "op": Op.WHILE_LOOP,
4911 "operands": (0, 1),
4912 "build_fcn": (
4913 build_while_loop,
4914 TosaTensorGen.tgBasic,
4915 TosaArgGen.agWhileLoop,
4916 ),
4917 "types": [DType.INT32],
4918 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004919 }
4920
Kevin Cheng550ccc52021-03-03 11:21:43 -08004921
Eric Kunzee5e26762020-10-13 16:11:07 -07004922class OutputShaper:
4923 # Methods in this class compute the expected output shape and datatype
4924 # for common classes of operations
    def __init__(self):
        """No per-instance state; all shape-inference helpers are static methods."""
        pass
4927
4928 # These methods return arguments that can be used for
4929 # creating a new output tensor
4930 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004931 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4932 if error_name != ErrorIf.RankMismatch:
4933 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004934 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004935
4936 shape = []
4937 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004938 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004939 shape.append(b.shape[i])
4940 else:
4941 shape.append(a.shape[i])
4942
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004943 if error_name == ErrorIf.WrongOutputType:
4944 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4945 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4946 outputDType = rng.choice(wrong_dtypes)
4947 else:
4948 outputDType = a.dtype
4949
4950 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004951
4952 @staticmethod
4953 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004954 assert len(a.shape) == len(b.shape)
4955 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004956
4957 shape = []
4958 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004959 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004960 shape.append(a.shape[i])
4961
Kevin Cheng550ccc52021-03-03 11:21:43 -08004962 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004963
4964 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004965 def unaryOp(ser, rng, a, error_name=None):
4966 if error_name == ErrorIf.WrongOutputType:
4967 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4968 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4969 outputDType = rng.choice(wrong_dtypes)
4970 else:
4971 outputDType = a.dtype
4972
4973 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004974
4975 @staticmethod
4976 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004977 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
4978 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004979
4980 shape = []
4981 for i in range(len(a.shape)):
4982 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4983
Kevin Cheng550ccc52021-03-03 11:21:43 -08004984 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004985
4986 @staticmethod
4987 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004988 assert len(a.shape) == len(b.shape)
4989 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004990
4991 # Do broadcast
4992 shape = []
4993 for i in range(len(a.shape)):
4994 if a.shape[i] == 1:
4995 shape.append(b.shape[i])
4996 else:
4997 shape.append(a.shape[i])
4998
4999 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08005000 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07005001
5002 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005003 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005004 shape = a.shape.copy()
Matthew Haddond6ce7252021-09-29 15:35:44 +01005005 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
5006 shape[axis] = 1
5007 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5008 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005009
Matthew Haddond6ce7252021-09-29 15:35:44 +01005010 if error_name == ErrorIf.WrongOutputType:
5011 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5012 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5013 outputDType = rng.choice(wrong_dtypes)
5014 else:
5015 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005016
Matthew Haddond6ce7252021-09-29 15:35:44 +01005017 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005018
5019 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005020 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005021 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005022
5023 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5024 del shape[axis]
5025
5026 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5027 remove = rng.choice([True, False])
5028 if remove and len(shape) > 1:
5029 del shape[0]
5030 else:
5031 shape.append(1)
5032 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5033 for i in range(len(shape)):
5034 shape[i] = shape[i] + rng.integers(1, 10)
5035
5036 if error_name == ErrorIf.WrongOutputType:
5037 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5038 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5039 outputDType = rng.choice(wrong_dtypes)
5040 else:
5041 outputDType = DType.INT32
5042
5043 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005044
5045 @staticmethod
5046 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
5047
5048 # IFM: NHWC
5049 # Filter: OHWI
5050 # OFM: NHWC
5051
5052 if len(padding) == 2:
5053 # Expand padding to 4 parameters in the case of transpose_conv2d
5054 # From H,W to T,B,L,R
5055 padding = [padding[0], padding[0], padding[1], padding[1]]
5056
Kevin Cheng550ccc52021-03-03 11:21:43 -08005057 h = (
5058 ifm.shape[1]
5059 - filter.shape[1]
5060 - (filter.shape[1] - 1) * (dilations[0] - 1)
5061 + padding[0]
5062 + padding[1]
5063 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005064
Kevin Cheng550ccc52021-03-03 11:21:43 -08005065 w = (
5066 ifm.shape[2]
5067 - filter.shape[2]
5068 - (filter.shape[2] - 1) * (dilations[1] - 1)
5069 + padding[2]
5070 + padding[3]
5071 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005072
Eric Kunzee5e26762020-10-13 16:11:07 -07005073 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
5074
Kevin Cheng3a478572021-01-22 17:21:02 -08005075 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005076 out_dtype = DType.INT32
5077 elif ifm.dtype == DType.INT16:
5078 out_dtype = DType.INT48
5079 elif ifm.dtype == DType.FLOAT:
5080 out_dtype = DType.FLOAT
5081 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005082 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005083
Kevin Cheng550ccc52021-03-03 11:21:43 -08005084 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005085
5086 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07005087 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
5088
5089 # IFM: NDHWC
5090 # Filter: ODHWI
5091 # OFM: NDHWC
5092
5093 d = (
5094 ifm.shape[1]
5095 - filter.shape[1]
5096 - (filter.shape[1] - 1) * (dilations[0] - 1)
5097 + padding[0]
5098 + padding[1]
5099 ) // strides[0] + 1
5100
5101 h = (
5102 ifm.shape[2]
5103 - filter.shape[2]
5104 - (filter.shape[2] - 1) * (dilations[1] - 1)
5105 + padding[2]
5106 + padding[3]
5107 ) // strides[1] + 1
5108
5109 w = (
5110 ifm.shape[3]
5111 - filter.shape[3]
5112 - (filter.shape[3] - 1) * (dilations[2] - 1)
5113 + padding[4]
5114 + padding[5]
5115 ) // strides[2] + 1
5116
5117 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
5118
5119 if ifm.dtype == DType.INT8:
5120 out_dtype = DType.INT32
5121 elif ifm.dtype == DType.INT16:
5122 out_dtype = DType.INT48
5123 elif ifm.dtype == DType.FLOAT:
5124 out_dtype = DType.FLOAT
5125 else:
5126 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
5127
5128 return ser.addOutput(ofm_shape, out_dtype)
5129
5130 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07005131 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
5132 # IFM: NHWC
5133 # Filter: HWCM
5134 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08005135 h = (
5136 ifm.shape[1]
5137 - filter.shape[0]
5138 - (filter.shape[0] - 1) * (dilations[0] - 1)
5139 + padding[0]
5140 + padding[1]
5141 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005142
Kevin Cheng550ccc52021-03-03 11:21:43 -08005143 w = (
5144 ifm.shape[2]
5145 - filter.shape[1]
5146 - (filter.shape[1] - 1) * (dilations[1] - 1)
5147 + padding[2]
5148 + padding[3]
5149 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005150
Eric Kunzee5e26762020-10-13 16:11:07 -07005151 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
5152
Kevin Cheng3a478572021-01-22 17:21:02 -08005153 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005154 out_dtype = DType.INT32
5155 elif ifm.dtype == DType.INT16:
5156 out_dtype = DType.INT48
5157 elif ifm.dtype == DType.FLOAT:
5158 out_dtype = DType.FLOAT
5159 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005160 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005161
Kevin Cheng550ccc52021-03-03 11:21:43 -08005162 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005163
5164 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005165 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005166 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005167 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005168 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005169 h = 1
5170 w = 1
5171 else:
5172 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
5173 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
5174
5175 if error_name == ErrorIf.PoolingOutputShapeMismatch:
5176 choices = [1, 2, 3, 4, 5]
5177 h = h + rng.choice(choices)
5178 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07005179
Eric Kunzee5e26762020-10-13 16:11:07 -07005180 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005181
5182 if error_name == ErrorIf.WrongOutputType:
5183 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5184 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
5185 outputDType = rng.choice(wrong_dtypes)
5186 else:
5187 outputDType = ifm.dtype
5188
5189 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005190
5191 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005192 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005193 # input: N, IC
5194 # filter: OC, IC
5195 # output: N, OC
5196
5197 output_shape = [input.shape[0], filter.shape[0]]
5198
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005199 if error_name == ErrorIf.WrongOutputType:
5200 if input.dtype == DType.INT8:
5201 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
5202 elif input.dtype == DType.INT16:
5203 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
5204 elif input.dtype == DType.FLOAT:
5205 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
5206 out_dtype = rng.choice(a=incorrect_types)
5207 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005208 out_dtype = DType.INT32
5209 elif input.dtype == DType.INT16:
5210 out_dtype = DType.INT48
5211 elif input.dtype == DType.FLOAT:
5212 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005213 elif error_name == ErrorIf.WrongInputType:
5214 # Pick some potentially correct output dtype if input type is incorrect
5215 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005216 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005217 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005218
Kevin Cheng550ccc52021-03-03 11:21:43 -08005219 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005220
5221 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005222 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07005223 # a: N, H, C
5224 # b: N, C, W
5225 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07005226
Kevin Cheng2d60f002021-06-09 14:18:32 -07005227 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005228
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005229 if error_name == ErrorIf.WrongOutputType:
5230 if a.dtype == DType.INT8:
5231 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
5232 elif a.dtype == DType.INT16:
5233 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
5234 elif a.dtype == DType.FLOAT:
5235 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
5236 out_dtype = rng.choice(a=incorrect_types)
5237 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005238 out_dtype = DType.INT32
5239 elif a.dtype == DType.INT16:
5240 out_dtype = DType.INT48
5241 elif a.dtype == DType.FLOAT:
5242 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005243 elif error_name == ErrorIf.WrongInputType:
5244 # Pick some potentially correct output dtype if input type is incorrect
5245 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005246 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005247 raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005248
Kevin Cheng550ccc52021-03-03 11:21:43 -08005249 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005250
5251 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01005252 def concatOp(ser, axis, *a):
5253 input1 = a[0]
5254 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07005255
Matthew Haddon818ab902021-07-27 09:12:49 +01005256 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07005257
Matthew Haddon818ab902021-07-27 09:12:49 +01005258 output_shape[axis] = input1.shape[axis]
5259
5260 for tensor in remaining_inputs:
5261 output_shape[axis] += tensor.shape[axis]
5262
5263 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005264
5265 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005266 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005267
5268 output_shape = a.shape.copy()
5269
5270 for i in range(len(output_shape)):
5271 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
5272
Matthew Haddone807aae2021-10-11 18:12:58 +01005273 # Fix negative output shape if error_if test causes it
5274 if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
5275 output_shape = [i if i >= 1 else 1 for i in output_shape]
5276
5277 if error_name == ErrorIf.WrongOutputType:
5278 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5279 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5280 outputDType = rng.choice(wrong_dtypes)
5281 else:
5282 outputDType = a.dtype
5283
5284 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005285
5286 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005287 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005288 output_shape = shape.copy()
5289
5290 totalElements = 1
5291 for i in a.shape:
5292 totalElements *= i
5293
5294 # If there are any -1 elements, figure out what that dimension must be
5295 totalOutputElements = 1
5296 for i in output_shape:
5297 if i != -1:
5298 totalOutputElements *= i
5299
5300 # And fill it in
5301 for i in range(len(output_shape)):
5302 if output_shape[i] == -1:
5303 output_shape[i] = totalElements // totalOutputElements
5304
Matthew Haddone807aae2021-10-11 18:12:58 +01005305 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5306 for i in range(len(output_shape)):
5307 output_shape[i] = output_shape[i] + rng.integers(1, 10)
5308
5309 if error_name == ErrorIf.WrongOutputType:
5310 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5311 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5312 outputDType = rng.choice(wrong_dtypes)
5313 else:
5314 outputDType = a.dtype
5315
5316 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005317
5318 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005319 def sliceOp(ser, rng, a, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005320
Matthew Haddone807aae2021-10-11 18:12:58 +01005321 if error_name == ErrorIf.WrongOutputType:
5322 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5323 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5324 outputDType = rng.choice(wrong_dtypes)
5325 else:
5326 outputDType = a.dtype
5327
5328 if error_name == ErrorIf.SizeOutputShapeMismatch:
5329 output_shape = size.copy()
5330 for index in range(len(output_shape)):
5331 if output_shape[index] <= 2:
5332 output_shape[index] = output_shape[index] + rng.choice([1, 2])
5333 else:
5334 output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2])
5335 else:
5336 output_shape = size.copy()
5337
5338 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005339
5340 @staticmethod
5341 def tileOp(ser, a, multiples):
5342
5343 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005344 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005345
5346 for i in range(len(output_shape)):
5347 output_shape[i] = a.shape[i] * multiples[i]
5348
Kevin Cheng550ccc52021-03-03 11:21:43 -08005349 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005350
5351 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005352 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005353 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005354
Kevin Cheng550ccc52021-03-03 11:21:43 -08005355 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005356
Matthew Haddone807aae2021-10-11 18:12:58 +01005357 if error_name == ErrorIf.IndexOutsideBounds:
5358 for i in range(len(output_shape)):
5359 output_shape[i] = a.shape[0]
5360 else:
5361 for i in range(len(output_shape)):
5362 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005363
Matthew Haddone807aae2021-10-11 18:12:58 +01005364 if error_name == ErrorIf.WrongOutputType:
5365 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5366 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5367 outputDType = rng.choice(wrong_dtypes)
5368 else:
5369 outputDType = a.dtype
5370
5371 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005372
5373 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08005374 def gatherOp(ser, values, indices):
5375 assert len(values.shape) == 3
5376 assert len(indices.shape) == 2
5377 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005378
Kevin Cheng77d0f762020-11-24 10:26:32 -08005379 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5380
Kevin Cheng550ccc52021-03-03 11:21:43 -08005381 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005382
5383 @staticmethod
5384 def scatterOp(ser, values_in, indices, input):
5385 assert len(values_in.shape) == 3
5386 assert len(indices.shape) == 2
5387 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005388 assert values_in.shape[0] == indices.shape[0] # N
5389 assert input.shape[1] == indices.shape[1] # W
5390 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005391
5392 output_shape = values_in.shape
5393
Kevin Cheng550ccc52021-03-03 11:21:43 -08005394 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005395
5396 @staticmethod
Kevin Chengfe392ce2021-10-18 21:51:55 +00005397 def tableOp(ser, input):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005398 # Same shape as the input, but dtype dependent on table dtype
Kevin Chengfe392ce2021-10-18 21:51:55 +00005399 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
5400 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005401 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005402
5403 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08005404 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005405 serializer,
5406 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005407 input,
5408 mode,
5409 stride,
5410 offset,
5411 shift,
5412 stride_fp,
5413 offset_fp,
5414 output_dims,
5415 input_dtype,
5416 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01005417 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08005418 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01005419 if error_name == ErrorIf.WrongRank:
5420 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
5421 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005422 if error_name == ErrorIf.BatchMismatch:
5423 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
5424 elif error_name == ErrorIf.ChannelMismatch:
5425 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
5426 else:
5427 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005428
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005429 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005430
5431 @staticmethod
5432 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005433 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005434
5435 @staticmethod
5436 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08005437 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005438 out_dtype = DType.INT32
5439 elif ifm.dtype == DType.INT16:
5440 out_dtype = DType.INT48
5441 elif ifm.dtype == DType.FLOAT:
5442 out_dtype = DType.FLOAT
5443 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005444 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005445
Kevin Cheng550ccc52021-03-03 11:21:43 -08005446 return ser.addOutput(output_shape, out_dtype)