blob: 04fce90003c1386e9ab6ef253a28264a3fb4eaea [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
# Put ../thirdparty/serialization_lib/python (resolved relative to the real,
# symlink-free location of this script) on the module search path so that the
# tosa_serializer / tosa imports that follow can be found.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, os.pardir, "thirdparty", "serialization_lib", "python")
)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# Convenience aliases for the flatc-generated types that should be enums, but
# aren't: each name is an instance of the generated class, used purely as a
# namespace of constants (e.g. DType.INT8).
DType, Op, ResizeMode = (
    tosa.DType.DType(),
    tosa.Op.Op(),
    tosa.ResizeMode.ResizeMode(),
)
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify one with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point value for a tensor of type ``dtype``.

        INT8 draws with ``testGen.randInt(-128, 128)`` and UINT8 with
        ``testGen.randInt(0, 256)``. Any other dtype gets 0, unless
        ``error_name`` is one of the *ZeroPointNotZero ERROR_IF cases, in
        which case a guaranteed non-zero value is returned instead.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            # A zero would not trigger the ERROR_IF under test.
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo(input_zp, output_zp).

        Only the zero point named by ``error_name`` (input or output) is
        generated in ERROR_IF mode.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo(input_zp, weights_zp).

        ``dtype_or_dtypeList`` is either a list of [input, weights,
        accumulator] dtypes or a single dtype used for all three.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo(a_zp, b_zp).

        For InputZeroPointNotZero both input zero points are made non-zero.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo(input_zp)."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32. The multiplier has
        31 fractional bits when ``scale32`` is true and 15 otherwise, so that
        ``abs(scaleFp) ~= multiplier * 2**-shift``. The sign of ``scaleFp``
        is discarded.

        Returns:
            (multiplier, shift) with ``multiplier <= 2**scaleBits`` and
            ``2 <= shift <= 62``.

        Raises:
            AssertionError: if the scale cannot be brought into that range
            (only a raw shift of 0, 1 or 63 is adjusted).
        """
        scaleBits = 31 if scale32 else 15

        # scaleFp == m * 2**shift, with 0.5 <= |m| < 1 (m == 0 for 0.0)
        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can carry m up to exactly 1.0; renormalise.
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        # From here on the encoded scale is multiplier * 2**-shift.
        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in the allowed range [2, 62].
        if shift == 0:
            # Scale too large to represent: the multiplier cannot grow, so it
            # is reduced to obtain legal parameters. The encoded scale is then
            # smaller than scaleFp.
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            # As above: legal parameters, encoded scale smaller than scaleFp.
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            # Halve the multiplier while decrementing the shift: this keeps
            # multiplier * 2**-shift unchanged (bar the dropped LSB) and keeps
            # the multiplier in range. Doubling it instead would quadruple the
            # scale and overflow the multiplier bound.
            multiplier = multiplier // 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
175
176
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated
    separately for each test.

    Most generators are called as gen(testGen, op, rank, error_name=None),
    where op["operands"] is a (placeholder_count, const_count) pair.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """One shape per operand, all identical unless testing RankMismatch."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # RankMismatch: after every operand except index 1, regenerate the
            # shape with a different rank for the following operand.
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Identical NHWC shapes for every operand (rank 4 unless WrongRank)."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Shapes for SCATTER: [values_in (N, K, C), input (N, W, C)]."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        # NOTE(review): randInt(0, 16) may yield W == 0 if its lower bound is
        # inclusive - confirm an empty W dimension is intended.
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Same-shape operands, with one randomly chosen dimension of one
        randomly chosen operand set to 1 to exercise broadcasting."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Shapes for CONV2D: [ifm NHWC, filter OHWI, bias (O,)]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Shapes for CONV3D: [ifm NDHWC, filter ODHWI, bias (O,)]."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Shapes for TRANSPOSE_CONV2D: [ifm NHWC, filter OHWI, bias (O,)]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Shapes for DEPTHWISE_CONV2D: [ifm NHWC, filter HWCM, bias (C*M,)]."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random multiplier M, but don't let it get too big because
        # the output depth is M * C. The divisor is clamped to 1 so that a
        # tensor_shape_range upper bound below 4 cannot cause a modulo by zero.
        filter_m = (
            testGen.makeShape(1)[0] % max(1, testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Shapes for FULLY_CONNECTED: [input (N, IC), filter (OC, IC), bias (OC,)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # The result must replace input_shape: that is the shape used below.
        input_shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Shapes for MATMUL: [a (N, H, C), b (N, C, W)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # The result must replace a_shape: that is the shape used below.
        a_shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Identical shapes for every operand plus 0-3 extra concat inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat (testGen.randInt(0, 4) of them,
        # independent of pl).
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the concat shape along ``axis`` for the const inputs.

        With more than two inputs, the first keeps the original shape and the
        rest take successive halves of the remaining length on ``axis``, the
        last one taking the remainder, so the const tensors stay small. When
        there are only two inputs, or ``axis`` is too short to split,
        ``shapeList`` itself is returned unchanged. ``shapeList`` is never
        mutated.
        """
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = shape[axis] // 2
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
490
491
Eric Kunzee5e26762020-10-13 16:11:07 -0700492class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800493 """Argument generators create exhaustive or random lists of attributes for operators that take
494 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
495 tuples where the descriptive_name is appended to the test name and the arglist is expanded
496 as arguments to the operator build function."""
497
Eric Kunzee5e26762020-10-13 16:11:07 -0700498 def __init__(self):
499 pass
500
501 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100502 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800503 """A trivial argument generator for operators that don't take any
504 non-tensor arguments"""
505 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
507 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100508 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800509 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700511 shape = shapeList[0]
512
Matthew Haddond6ce7252021-09-29 15:35:44 +0100513 if error_name == ErrorIf.AxisSmallerZero:
514 small_axis = testGen.rng.integers(-5, 0)
515 axes.append(("axis{}".format(small_axis), [small_axis]))
516 elif error_name == ErrorIf.AxisLargerRank:
517 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
518 axes.append(("axis{}".format(large_axis), [large_axis]))
519 else:
520 for a in range(0, len(shape)):
521 axes.append(("axis{}".format(a), [a]))
522
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 return axes
524
525 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100526 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700527 arg_list = []
528
529 ifm_shape = shapeList[0]
530 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100531 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
532 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700533
Les Bell7aa69f42021-09-20 10:44:07 +0100534 # Check the rank
535 rank = 5 if opName.startswith("conv3d") else 4
536 assert len(ifm_shape) == rank
537 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700538
Les Bell7aa69f42021-09-20 10:44:07 +0100539 # kernel rank omits batch and channels
540 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
Les Bell7aa69f42021-09-20 10:44:07 +0100542 # Generate comprehensive argument lists
543 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
544 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
545 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
546 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
547 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
548 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700549
Les Bell7aa69f42021-09-20 10:44:07 +0100550 # add some oversize argument values
551 if max(ifm_shape) < 64:
552 bigPadding = 9
553 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
554 bigStride = 8
555 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
556 bigDilation = 7
557 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100558
559 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100560 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
561 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
562 if sparsity < 13:
563 sparsity = 1
564 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
565 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100566 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100567 for s in sorted(list(strides)):
568 for p in sorted(list(paddings)):
569 for d in sorted(list(dilations)):
570 if (n % sparsity == 0
571 # padding must not exceed the kernel size ?
572 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
573 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
574 # the padded shape must exceed the kernel size
575 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
576 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
577 # the padded shape must exceed the dilation
578 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
579 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
580 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100581 arg_list.append(
582 (
583 "st{}_pad{}_dilat{}".format(
584 "".join([str(x) for x in s]),
585 "".join([str(x) for x in p]),
586 "".join([str(x) for x in d]),
587 ),
588 [s, p, d],
589 )
590 )
591 n += 1
592
Kevin Cheng1533b852021-09-01 12:51:58 -0700593 return arg_list
594
595 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100596 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 arg_list = []
598
599 ifm_shape = shapeList[0]
600 filter_shape = shapeList[1]
601
602 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800603 assert len(ifm_shape) == 4
604 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Les Bell7aa69f42021-09-20 10:44:07 +0100606 # Generate comprehensive argument lists
607 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
608 paddings = {x for x in itertools.product(*([p_vals] * 2))}
609 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
610 strides = {x for x in itertools.product(*([s_vals] * 2))}
611 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
612 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
Les Bell7aa69f42021-09-20 10:44:07 +0100614 # add some oversize argument values
615 if max(ifm_shape) < 64:
616 bigPadding = 9
617 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
618 bigStride = 8
619 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
620 bigDilation = 7
621 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Les Bell7aa69f42021-09-20 10:44:07 +0100623 # There are too many parameter combinations, so generate them sparsely
624 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
625 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
626 if sparsity < 13:
627 sparsity = 1
628 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
629 sparsity += 1
630 n = 0
631 for s in sorted(list(strides)):
632 for p in sorted(list(paddings)):
633 for d in sorted(list(dilations)):
634 if n % sparsity == 0:
635 # Determine the output shape
636 oh = (
637 ifm_shape[1]
638 - filter_shape[1]
639 - (filter_shape[1] - 1) * (d[0] - 1)
640 + 2 * p[0]
641 ) // s[0] + 1
642 ow = (
643 ifm_shape[2]
644 - filter_shape[2]
645 - (filter_shape[2] - 1) * (d[1] - 1)
646 + 2 * p[1]
647 ) // s[1] + 1
648 os = [ifm_shape[0], oh, ow, filter_shape[0]]
649 arg_list.append(
650 (
651 "st{}_pad{}_dilat{}_os{}".format(
652 "".join([str(x) for x in s]),
653 "".join([str(x) for x in p]),
654 "".join([str(x) for x in d]),
655 "x".join([str(x) for x in os]),
656 ),
657 [s, p, d, os],
658 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 )
Les Bell7aa69f42021-09-20 10:44:07 +0100660 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
662 return arg_list
663
664 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100665 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700666 arg_list = []
667 rank = len(shapeList[0])
668
Les Bell7ffccce2021-07-28 15:37:02 +0100669 # Exhaustively test combinations of padding on each side of each dimension
670 # - the range of padding values is defined by pad_min and pad_max
671 # - for padding >9, the name format needs to be more distinctive
672 pad_min, pad_max = 0, 1
673 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100674 if error_name == ErrorIf.PadSmallerZero:
675 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100676 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
677 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700678
Les Bell7ffccce2021-07-28 15:37:02 +0100679 for paddings in shape_pad_values:
680 name = "pad"
681 for r in range(rank):
682 before, after = paddings[r]
683 name = f"{name}{before}{after}"
684 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700685
686 return arg_list
687
688 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100689 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700690 arg_list = []
691
692 shape = shapeList[0]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100693 if error_name != ErrorIf.WrongRank:
694 assert len(shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700695
Les Bell7aa69f42021-09-20 10:44:07 +0100696 # Generate comprehensive argument lists
697 p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
698 paddings = {x for x in itertools.product(*([p_vals] * 4))}
699 s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
700 strides = {x for x in itertools.product(*([s_vals] * 2))}
701 k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
702 kernels = {x for x in itertools.product(*([k_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700703
Les Bell7aa69f42021-09-20 10:44:07 +0100704 # add some oversize argument values
705 bigStride = 7
706 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
707 bigKernel = 6
708 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
709 if max(shape) < 64:
710 # padding must be less than the kernel size
711 bigPadding = bigKernel - 1
712 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700713
Les Bell7aa69f42021-09-20 10:44:07 +0100714 # There are too many parameter combinations, so generate them sparsely
715 sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
716 n = 0
717 for s in sorted(list(strides)):
718 for p in sorted(list(paddings)):
719 for k in sorted(list(kernels)):
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100720 if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
721 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
722 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
723 arg_list.append(
724 (
725 "st{}_kern{}_pad{}".format(
726 "".join([str(x) for x in sNew]),
727 "".join([str(x) for x in kNew]),
728 "".join([str(x) for x in pNew]),
729 ),
730 [sNew, pNew, kNew],
731 )
732 )
733 elif (n % sparsity == 0
Les Bell7aa69f42021-09-20 10:44:07 +0100734 # padding must not exceed the kernel size
735 and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
736 # the padded shape must exceed the kernel size
737 and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
738 ):
739 arg_list.append(
740 (
741 "st{}_kern{}_pad{}".format(
742 "".join([str(x) for x in s]),
743 "".join([str(x) for x in k]),
744 "".join([str(x) for x in p]),
745 ),
746 [s, p, k],
747 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800748 )
Les Bell7aa69f42021-09-20 10:44:07 +0100749 n += 1
750
Eric Kunzee5e26762020-10-13 16:11:07 -0700751 return arg_list
752
753 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100754 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700755 arg_list = []
756
757 # Enumerate the output types here
758 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800759 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700760 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800761 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700762 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800763 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700764 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800765 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700766 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800767 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700768 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800769 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700770
771 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800772 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700773
774 return arg_list
775
776 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100777 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700778 arg_list = []
779
780 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100781 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100782 if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
783 continue
784 if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100785 # The only output dtype for UINT8 is INT8, skip all other combinations
786 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100787 if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100788 # The only input dtype for UINT8 is INT8, skip all other combinations
789 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100790 if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
791 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100792
Kevin Cheng550ccc52021-03-03 11:21:43 -0800793 for scale32 in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100794 if error_name == ErrorIf.ScaleTrue and scale32 == False:
795 continue
796 elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
797 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800798 for double_round in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100799 if error_name == ErrorIf.ScaleNotTrue and double_round == False:
800 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800801 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700802
Matthew Haddonc2025212021-10-08 21:21:05 +0100803 if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
Eric Kunzee5e26762020-10-13 16:11:07 -0700804 # Illegal condition. Must be scale32=False
805 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100806 if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100807 # Illegal condition. ERROR_IF(!scale32 && double_round)
808 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700809
Kevin Cheng550ccc52021-03-03 11:21:43 -0800810 arg_list.append(
811 (
812 "out{}_sc{}_dr{}_pc{}".format(
813 DTypeNames[dtype],
814 int(scale32),
815 int(double_round),
816 int(per_channel),
817 ),
818 [dtype, scale32, double_round, per_channel],
819 )
820 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700821
822 return arg_list
823
Kevin Chengaee1fac2020-11-11 13:54:06 -0800824 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100825 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800826 arg_list = []
827
828 if dtype is DType.INT32:
829 for p in range(testGen.args.num_rand_permutations):
830
831 shift = testGen.randInt(0, 32)
832
Kevin Cheng550ccc52021-03-03 11:21:43 -0800833 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800834 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100835 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800836
837 return arg_list
838
839 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100840 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800841 arg_list = []
842
Kevin Cheng550ccc52021-03-03 11:21:43 -0800843 arg_list.append(("roundTrue", [True]))
844 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800845
846 return arg_list
847
Eric Kunzee5e26762020-10-13 16:11:07 -0700848 # Helper function for reshape. Gets some factors of a larger number.
849 @staticmethod
850 def getFactors(val, start=1):
851 factors = []
852
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100853 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700854 if (val % i) == 0:
855 factors.append(i)
856
857 return factors
858
859 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100860 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700861 arg_list = []
862
863 origShape = shapeList[0]
864
865 totalElements = 1
866 for s in origShape:
867 totalElements *= s
868
869 # This code is NOT fast. Fortunately, the numbers are fairly small.
870 factors = TosaArgGen.getFactors(totalElements)
871
872 for p in range(testGen.args.num_rand_permutations):
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100873 newRank = testGen.randInt(1, 7)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800874 if len(factors) < newRank:
Eric Kunzee5e26762020-10-13 16:11:07 -0700875 continue
876
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100877 found = True
878 # escape_counter breaks while loop if it continues on for too long
879 escape_counter = 0
880 while found:
881 newShape = []
882 # Generate newShape ensuring it isn't a duplicate
883 remainingElements = totalElements
884 shuffledFactors = testGen.rng.permutation(factors)
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100885 for i in range(1, newRank):
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100886 # pick rank-1 factors
887 newShape.append(shuffledFactors[0])
888 remainingElements = remainingElements // shuffledFactors[0]
889 shuffledFactors = testGen.rng.permutation(
890 TosaArgGen.getFactors(remainingElements)
891 )
892 newShape.append(remainingElements)
Eric Kunzee5e26762020-10-13 16:11:07 -0700893
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100894 # Toss in a -1 sometimes
895 minusOne = testGen.randInt(0, newRank * 4)
896 if minusOne < newRank:
897 newShape[minusOne] = -1
Eric Kunzee5e26762020-10-13 16:11:07 -0700898
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100899 # Check for duplicates
900 found = False
901 for name, other_shape in arg_list:
902 if other_shape[0] == newShape:
903 found = True
904 break
905
906 escape_counter += 1
907 if escape_counter >= 100:
908 break
909
910 if not found:
911 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700912
913 return arg_list
914
Eric Kunzee5e26762020-10-13 16:11:07 -0700915 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100916 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 arg_list = []
918
919 ifm_shape = shapeList[0]
920
Matthew Haddone807aae2021-10-11 18:12:58 +0100921
922 if error_name == ErrorIf.IndexOutsideBounds:
923 incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
924 incorrect_small_index = range(-len(ifm_shape), 0)
925 permutations = [p for p in itertools.permutations(incorrect_large_index)]
926 permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
927 elif error_name == ErrorIf.IndexUsedTwice:
928 # Create list with a duplicated index
929 perm_range = list(range(len(ifm_shape)))
930 index_choice = testGen.rng.choice(range(len(perm_range)))
931 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
932 permutations = [p for p in itertools.permutations(perm_range)]
933
934
935 else:
936 # Get all permutations
937 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700938
Jeremy Johnsona6185572021-06-21 15:55:35 +0100939 # Limit to possible permutations from shape dimension or argument setting
940 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700941
Jeremy Johnsona6185572021-06-21 15:55:35 +0100942 # Get random permutation generator that uses all permutations
943 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700944
Jeremy Johnsona6185572021-06-21 15:55:35 +0100945 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700946 arg_list = [
947 ("perm{}".format(p), [random_permutations[p].tolist()])
948 for p in range(limit)
949 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700950 return arg_list
951
952 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100953 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700954 arg_list = []
955
956 ifm_shape = shapeList[0]
957 rank = len(ifm_shape)
958
959 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +0100960 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700961 size = []
962
Kevin Cheng550ccc52021-03-03 11:21:43 -0800963 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700964
965 for i in range(rank):
966 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +0100967 start.append(testGen.randInt(0, ifm_shape[i]))
968 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700969
970 # Invalid slice size?
971 if size[i] == 0:
972 valid = False
973 else:
Matthew Haddone807aae2021-10-11 18:12:58 +0100974 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -0700975 size.append(1)
976
977 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +0100978 # If ERROR_IF test required then incorrect start, size will be returned
979 start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
980 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700981 return arg_list
982
983 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100984 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700985 arg_list = []
986
987 ifm_shape = shapeList[0]
988 rank = len(ifm_shape)
989
990 for p in range(testGen.args.num_rand_permutations):
991
992 # Pick a few random, but small multiple values
993 # because otherwise this has a tendency to generate
994 # enormous tensors
995 multiples = []
996 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +0100997 if ifm_shape[i] > 1000:
998 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
999 multiples.append(1)
1000 elif max(ifm_shape) > 1000:
1001 multiples.append(2)
1002 else:
1003 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001004 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001005
1006 return arg_list
1007
1008 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001009 def agResize(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001010 arg_list = []
1011
1012 ifm_shape = shapeList[0]
Matthew Haddon848efb42021-09-09 12:30:53 +01001013 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001014
1015 # Exclude illegal {mode, type} configurations. Pick legal output types
Matthew Haddon848efb42021-09-09 12:30:53 +01001016 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +01001017 outputDTypeList = [DType.INT8]
Matthew Haddon848efb42021-09-09 12:30:53 +01001018 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001019 outputDTypeList = [DType.INT16]
Matthew Haddon848efb42021-09-09 12:30:53 +01001020 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +01001021 outputDTypeList = [DType.INT32]
Matthew Haddon848efb42021-09-09 12:30:53 +01001022 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001023 outputDTypeList = [DType.INT48]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001024 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001025 outputDTypeList = [DType.FLOAT]
Matthew Haddon848efb42021-09-09 12:30:53 +01001026 elif error_name == ErrorIf.WrongInputType:
1027 # If an incorrect input type is used then we set a 'correct'
1028 # output type to avoid other errors
1029 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -07001030 else:
1031 continue
1032
1033 for outputDType in outputDTypeList:
1034 for perm in range(testGen.args.num_rand_permutations):
Eric Kunzee5e26762020-10-13 16:11:07 -07001035 # Randomly generate legal output dimensions and shift
1036 # and then compute the stride and offset based on them
Matthew Haddone86fd342021-09-07 16:12:21 +01001037 # A output_dim of 1 will cause offset to exceed allowed range
1038 # so minimum value 2 produced below
1039 output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
1040 while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
1041 output_dims[0] += 1
1042 while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
1043 output_dims[1] += 1
1044
Kevin Cheng77d0f762020-11-24 10:26:32 -08001045 in_center_h = (ifm_shape[1] - 1) / 2.0
1046 in_center_w = (ifm_shape[2] - 1) / 2.0
1047 out_center_h = (output_dims[0] - 1) / 2.0
1048 out_center_w = (output_dims[1] - 1) / 2.0
Eric Kunzee5e26762020-10-13 16:11:07 -07001049
Kevin Cheng77d0f762020-11-24 10:26:32 -08001050 fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
1051 fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
1052 fp_offset_y = in_center_h - fp_stride_y * out_center_h
1053 fp_offset_x = in_center_w - fp_stride_x * out_center_w
Eric Kunzee5e26762020-10-13 16:11:07 -07001054
Kevin Cheng77d0f762020-11-24 10:26:32 -08001055 if outputDType == DType.FLOAT:
1056 shift = 0
1057 stride = [0, 0]
1058 offset = [0, 0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001059 stride_fp = [fp_stride_y, fp_stride_x]
1060 offset_fp = [fp_offset_y, fp_offset_x]
Matthew Haddone86fd342021-09-07 16:12:21 +01001061
1062 if error_name is not None:
Matthew Haddon848efb42021-09-09 12:30:53 +01001063 shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
Matthew Haddone86fd342021-09-07 16:12:21 +01001064 testGen,
1065 error_name,
Matthew Haddon848efb42021-09-09 12:30:53 +01001066 mode,
1067 dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001068 shapeList,
1069 outputDType,
1070 shift,
1071 stride,
1072 stride_fp,
1073 offset,
1074 offset_fp
1075 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001076 else:
1077 outputDTypeNew = outputDType
Matthew Haddone86fd342021-09-07 16:12:21 +01001078
Kevin Cheng550ccc52021-03-03 11:21:43 -08001079 arg_list.append(
1080 (
1081 "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
Matthew Haddon848efb42021-09-09 12:30:53 +01001082 "N" if mode == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -08001083 output_dims[0],
1084 output_dims[1],
Matthew Haddon848efb42021-09-09 12:30:53 +01001085 testGen.typeStr(outputDTypeNew),
Kevin Cheng550ccc52021-03-03 11:21:43 -08001086 stride_fp[0],
1087 stride_fp[1],
1088 offset_fp[0],
1089 offset_fp[1],
1090 ),
1091 [
Matthew Haddon848efb42021-09-09 12:30:53 +01001092 mode,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001093 stride,
1094 offset,
1095 shift,
1096 stride_fp,
1097 offset_fp,
1098 output_dims,
1099 dtype,
Matthew Haddon848efb42021-09-09 12:30:53 +01001100 outputDTypeNew,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001101 ],
1102 )
1103 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001104 else:
1105 shift = 11
1106 unit = float(1 << shift)
1107 stride_y = int(round(fp_stride_y * unit))
1108 stride_x = int(round(fp_stride_x * unit))
1109 offset_y = int(round(fp_offset_y * unit))
1110 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -07001111
Kevin Cheng550ccc52021-03-03 11:21:43 -08001112 while (
Matthew Haddone86fd342021-09-07 16:12:21 +01001113 stride_y >= (16 << shift)
1114 or stride_x >= (16 << shift)
1115 or offset_y >= (16 << shift)
1116 or offset_x >= (16 << shift)
1117 or offset_y <= (-16 << shift)
1118 or offset_x <= (-16 << shift)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001119 ):
Kevin Cheng77d0f762020-11-24 10:26:32 -08001120 shift = shift - 1
1121 unit = float(1 << shift)
1122 stride_y = int(round(fp_stride_y * unit))
1123 stride_x = int(round(fp_stride_x * unit))
1124 offset_y = int(round(fp_offset_y * unit))
1125 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -07001126
Kevin Cheng550ccc52021-03-03 11:21:43 -08001127 stride = [stride_y, stride_x]
1128 offset = [offset_y, offset_x]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001129
1130 stride_fp = [0.0, 0.0]
1131 offset_fp = [0.0, 0.0]
1132
Matthew Haddone86fd342021-09-07 16:12:21 +01001133 if error_name is not None:
Matthew Haddon848efb42021-09-09 12:30:53 +01001134 shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
Matthew Haddone86fd342021-09-07 16:12:21 +01001135 testGen,
1136 error_name,
Matthew Haddon848efb42021-09-09 12:30:53 +01001137 mode,
1138 dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001139 shapeList,
1140 outputDType,
1141 shift,
1142 stride,
1143 stride_fp,
1144 offset,
1145 offset_fp
1146 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001147 else:
1148 outputDTypeNew = outputDType
Matthew Haddone86fd342021-09-07 16:12:21 +01001149
Kevin Cheng550ccc52021-03-03 11:21:43 -08001150 arg_list.append(
1151 (
1152 "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
Matthew Haddon848efb42021-09-09 12:30:53 +01001153 "N" if mode == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -08001154 shift,
1155 output_dims[0],
1156 output_dims[1],
Matthew Haddon848efb42021-09-09 12:30:53 +01001157 testGen.typeStr(outputDTypeNew),
Kevin Cheng550ccc52021-03-03 11:21:43 -08001158 stride[0],
1159 stride[1],
1160 offset[0],
1161 offset[1],
1162 ),
1163 [
Matthew Haddon848efb42021-09-09 12:30:53 +01001164 mode,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001165 stride,
1166 offset,
1167 shift,
1168 stride_fp,
1169 offset_fp,
1170 output_dims,
1171 dtype,
Matthew Haddon848efb42021-09-09 12:30:53 +01001172 outputDTypeNew,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001173 ],
1174 )
1175 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001176
1177 return arg_list
1178
Matthew Haddon1c00b712021-10-01 15:51:03 +01001179 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001180 # CondIf generates the condition values here.
1181 # Convert to tensors in the build function, along with the
1182 # then and else blocks
1183 arg_list = []
1184
1185 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001186 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001187
1188 return arg_list
1189
Matthew Haddon1c00b712021-10-01 15:51:03 +01001190 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001191 # While loop: 0 iterations, 1, more than 1
1192 arg_list = []
1193
1194 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001195 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001196
1197 return arg_list
1198
class TosaErrorIfArgGen:
    """Helpers that turn valid operator arguments into ERROR_IF test arguments.

    All methods are static.  Most take the test generator ``testGen`` so
    they can draw random values from ``testGen.rng``.
    """

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Corrupt RESIZE arguments so that ``error_name`` is triggered.

        FLOAT outputs modify ``stride_fp`` or ``shift``.  Integer outputs
        modify ``stride``, ``offset`` or ``shift``.  ``StrideLargerDimension``
        assigns into the ``stride``/``stride_fp`` list that was passed in.
        For ``WrongOutputType``, an output type that is illegal for
        ``(mode, dtype)`` is chosen at random.

        Returns:
            ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                # Two values drawn from [-2, -1)
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                # integers(-3, 1) excludes the upper bound, so shift is in
                # [-3, 0].  Keep stride/offset just under the limit for that
                # shift to avoid other ERROR_IF checks.
                shift = testGen.rng.integers(-3, 1)
                limit = (16 >> -shift) - 1
                stride = [limit, limit]
                offset = [limit, limit]
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Corrupt pooling (stride, pad, kernel) values for ``error_name``.

        Returns:
            The modified ``(stride, pad, kernel)``.  Returns
            ``(None, None, None)`` when ``error_name`` is not a pooling
            ERROR_IF.  For ``StrideSmallerOne`` it also returns that when some
            pad is not smaller than the matching kernel dimension.
        """
        if (error_name == ErrorIf.StrideSmallerOne
                # padding must not exceed the kernel size
                and pad[0] < kernel[0] and pad[1] < kernel[0]
                and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]),
                           testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]),
                           testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                        testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                        testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                        testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if RESCALE does not allow ``input_dtype`` -> ``output_dtype``.

        Input types without a rule here (anything but INT8, INT16, INT32,
        INT48 or UINT8) return False.
        """
        if input_dtype == DType.INT8:
            return output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]
        if input_dtype in [DType.INT16, DType.INT32, DType.INT48]:
            return output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
        if input_dtype == DType.UINT8:
            return output_dtype != DType.INT8
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up the input/output tensor name lists for ERROR_IF checks.

        "WrongInputList" randomly appends a dummy input or drops the last
        input.  "WrongOutputList" randomly appends a dummy output or empties
        the outputs.  New lists are returned and the caller's lists are never
        modified.  For any other ``error_name`` the arguments are returned
        as-is.

        Returns:
            ``(input_list, output_list)``
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list = input_list + ['eiDummyInput']
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list = output_list + ['eiDummyOutput']
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimension(shape, error_name):
        """Keep WrongRank test sizes small by restricting high-rank shapes.

        For ``WrongRank``, dimension 4 of a shape with rank > 4 is set to 1 in
        place.  The (possibly modified) shape is returned.
        """
        if error_name == ErrorIf.WrongRank:
            if len(shape) > 4:
                shape[4] = 1

        return shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Corrupt SLICE (start, size) lists for ``error_name``.

        StartSmallerZero draws negative starts.  SizeSmallerEqualZero draws
        non-positive sizes.  StartSizeOutsideBounds starts at each dimension's
        last element with a size of 2-4.  InputSizeStartLengthMismatch drops
        or appends one entry in both lists.  The lists passed in are never
        modified.  For any other ``error_name`` they are returned unchanged.

        Returns:
            ``(start, size)``
        """
        if error_name == ErrorIf.StartSmallerZero:
            newStart = [testGen.rng.choice([-3, -2, -1]) for _ in input_shape]
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = [testGen.rng.choice([-3, -2, -1, 0]) for _ in input_shape]
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart = [dim - 1 for dim in input_shape]
            newSize = [testGen.rng.choice([2, 3, 4]) for _ in input_shape]
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                return start[1:], size[1:]
            return start + [1], size + [1]
        else:
            return start, size
1362
Matthew Haddone86fd342021-09-07 16:12:21 +01001363class TosaErrorValidator:
1364
Matthew Haddon848efb42021-09-09 12:30:53 +01001365 @staticmethod
1366 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1367 # Check ERROR_IF statements
1368
1369 for val_fcn in validator_fcns:
1370 val_result = val_fcn(True, **kwargs)
1371
1372 validator_name = val_result['error_name']
1373 error_result = val_result['error_result']
1374 error_reason = val_result['error_reason']
1375
1376 if error_result:
1377 if error_name == validator_name:
1378 serializer.setExpectedReturnCode(2, error_reason)
1379 else:
1380 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1381 return None # Return None to delete test if wrong ERROR_IF is hit
1382 else:
1383 if error_name == validator_name:
1384 print(f"No ERROR_IF hit for {error_name}")
1385 return None
1386
1387 @staticmethod
1388 def evWrongInputType(check=False, **kwargs):
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001389 all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
Matthew Haddon848efb42021-09-09 12:30:53 +01001390
1391 # Find the unsupported input data types
1392 assert 'op' in kwargs
1393 op = kwargs['op']
1394 input_dtypes = op['types']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001395
1396 allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
1397 wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
Matthew Haddon848efb42021-09-09 12:30:53 +01001398
1399 error_name = ErrorIf.WrongInputType
1400 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1401 error_result = False
1402 error_reason = "Input data type not supported for this operator"
1403
1404 if check:
1405 input_dtype = kwargs['input_dtype']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001406 if op['op'] == Op.FULLY_CONNECTED:
1407 if input_dtype not in allowed_input_dtypes:
1408 error_result = True
1409 elif input_dtype not in input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001410 error_result = True
1411
1412 info_dict = {
1413 "error_name": error_name,
1414 "error_result": error_result,
1415 "error_reason": error_reason,
1416 "param_reqs": param_reqs
1417 }
1418 return info_dict
1419
@staticmethod
def evWrongOutputType(check=False, **kwargs):
    """Build (and optionally evaluate) the WrongOutputType ERROR_IF check.

    The permitted output dtype depends on the op and the input dtype.
    RESIZE (which also depends on ``mode``), RESCALE, FULLY_CONNECTED/MATMUL
    and ARGMAX have explicit rules. Every other op must keep the input
    dtype. An input dtype that has no rule for the given op is never
    flagged.

    Args:
        check: if True, evaluate the rule using ``input_dtype``,
            ``output_dtype`` and ``op`` (plus ``mode`` for RESIZE).
        **kwargs: the keys listed above; nothing is read when ``check`` is
            False.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        op = kwargs['op']
        op_id = op['op']

        if op_id == Op.RESIZE:
            mode = kwargs['mode']
            if input_dtype == DType.FLOAT:
                required = DType.FLOAT
            elif mode == ResizeMode.NEAREST and input_dtype in (DType.INT8, DType.INT16):
                required = input_dtype
            elif mode == ResizeMode.BILINEAR and input_dtype == DType.INT8:
                required = DType.INT32
            elif mode == ResizeMode.BILINEAR and input_dtype == DType.INT16:
                required = DType.INT48
            else:
                required = None
            if required is not None and output_dtype != required:
                error_result = True

        elif op_id == Op.RESCALE:
            if input_dtype == DType.INT8:
                permitted = (DType.UINT8, DType.INT8, DType.INT16, DType.INT32)
            elif input_dtype in (DType.INT16, DType.INT32, DType.INT48):
                permitted = (DType.INT8, DType.INT16, DType.INT32)
            elif input_dtype == DType.UINT8:
                permitted = (DType.INT8,)
            else:
                permitted = None
            if permitted is not None and output_dtype not in permitted:
                error_result = True

        elif op_id in (Op.FULLY_CONNECTED, Op.MATMUL):
            # Required output (accumulator) dtype per supported input dtype
            if input_dtype == DType.INT8:
                required = DType.INT32
            elif input_dtype == DType.INT16:
                required = DType.INT48
            elif input_dtype == DType.FLOAT:
                required = DType.FLOAT
            else:
                required = None
            if required is not None and output_dtype != required:
                error_result = True

        elif op_id == Op.ARGMAX:
            if input_dtype in (DType.INT8, DType.INT16, DType.FLOAT) and output_dtype != DType.INT32:
                error_result = True

        elif output_dtype != input_dtype:
            error_result = True

    return {
        "error_name": ErrorIf.WrongOutputType,
        "error_result": error_result,
        "error_reason": "Output data type not supported for this configuration of operator",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1476
@staticmethod
def evWrongRank(check=False, **kwargs):
    """Build (and optionally evaluate) the WrongRank ERROR_IF check.

    Args:
        check: if True, test the rank of ``kwargs['input_shape']``.
        **kwargs: ``op`` (always required) - op-table entry with ``'op'``
            and an inclusive ``'rank'`` (min, max) pair; ``input_shape``
            when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``; ``param_reqs['rank']`` lists incorrect ranks.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    min_rank, max_rank = op['rank']
    supported_ranks = range(min_rank, max_rank + 1)

    if op['op'] in [Op.RESIZE]:
        # Fixed incorrect ranks for RESIZE: the minimum incorrect rank is 3
        # to avoid index errors.
        bad_ranks = [3, 5]
    else:
        # Ranks 1..5 outside the supported range; ranks not above the
        # minimum are dropped to avoid index errors.
        outside = list(set((1, 2, 3, 4, 5)) - set(supported_ranks))
        bad_ranks = [r for r in outside if r > min_rank]

    error_result = False
    if check:
        rank = len(kwargs['input_shape'])
        # Ops that need one exact rank; when that rank matches, the generic
        # range check below still applies.
        exact_rank = {
            Op.RESIZE: 4,
            Op.AVG_POOL2D: 4,
            Op.MAX_POOL2D: 4,
            Op.FULLY_CONNECTED: 2,
            Op.MATMUL: 3,
        }.get(op['op'])
        if exact_rank is not None and rank != exact_rank:
            error_result = True
        elif rank not in supported_ranks:
            error_result = True

    return {
        "error_name": ErrorIf.WrongRank,
        "error_result": error_result,
        "error_reason": "Rank not supported for this operator",
        "param_reqs": {"rank": bad_ranks, "dtype": None, "shape": None},
    }
1518
@staticmethod
def evWrongInputList(check=False, **kwargs):
    """Build (and optionally evaluate) the WrongInputList ERROR_IF check.

    Args:
        check: if True, compare ``len(input_list)`` with ``num_operands``.
        **kwargs: when ``check`` is True: ``op``, ``input_list`` and
            ``num_operands``.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        op = kwargs['op']
        input_list = kwargs['input_list']
        num_operands = kwargs['num_operands']
        # PAD and TRANSPOSE both add an extra const tensor in their build
        # function, so one more input is expected.
        extra_inputs = 1 if op['op'] in [Op.PAD, Op.TRANSPOSE] else 0
        if len(input_list) != num_operands + extra_inputs:
            error_result = True

    return {
        "error_name": ErrorIf.WrongInputList,
        "error_result": error_result,
        "error_reason": "Op input list does not match expected input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1545
@staticmethod
def evWrongOutputList(check=False, **kwargs):
    """Build (and optionally evaluate) the WrongOutputList ERROR_IF check.

    Args:
        check: if True, flag an ``output_list`` whose length is not 1.
        **kwargs: ``output_list`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        # Assumes exactly one output per op; an op with several outputs
        # would be reported incorrectly.
        if len(kwargs['output_list']) != 1:
            error_result = True

    return {
        "error_name": ErrorIf.WrongOutputList,
        "error_result": error_result,
        "error_reason": "Op output list does not match expected output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
Matthew Haddone86fd342021-09-07 16:12:21 +01001566
@staticmethod
def evMaxDimExceeded(check=False, **kwargs):
    """Build (and optionally evaluate) the MaxDimExceeded ERROR_IF check.

    Args:
        check: if True, flag input dims 1/2 or output dims 0/1 above 16384.
        **kwargs: when ``check`` is True: ``input_shape`` (dims 1 and 2
            are checked) and ``output_shape`` (only (OH, OW)).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (rank-4 INT8 shapes with an oversized dimension).
    """
    limit = 16384
    error_result = False

    if check:
        input_shape = kwargs['input_shape']
        output_shape = kwargs['output_shape']  # Note this is just (OH, OW)
        input_too_big = input_shape[1] > limit or input_shape[2] > limit
        if input_too_big or output_shape[0] > limit or output_shape[1] > limit:
            error_result = True

    return {
        "error_name": ErrorIf.MaxDimExceeded,
        "error_result": error_result,
        "error_reason": "At least one maximum dimension is larger than 16384",
        "param_reqs": {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        },
    }
1594
@staticmethod
def evBatchMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the BatchMismatch ERROR_IF check.

    Args:
        check: if True, compare dim 0 of input and result shapes, but only
            when the input rank is within the op's supported range.
        **kwargs: ``op`` (always required, with a ``'rank'`` (min, max)
            pair); ``input_shape`` and ``result_tensor`` when ``check``.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (rank 4).
    """
    assert 'op' in kwargs
    lowest, highest = kwargs['op']['rank']
    allowed_ranks = range(lowest, highest + 1)

    error_result = False
    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
        rank_ok = len(in_shape) in allowed_ranks
        if rank_ok and in_shape[0] != out_shape[0]:
            error_result = True

    return {
        "error_name": ErrorIf.BatchMismatch,
        "error_result": error_result,
        "error_reason": "Input batch size not equal to output batch size",
        "param_reqs": {"rank": [4, 4], "dtype": None, "shape": None},
    }
1621
@staticmethod
def evChannelMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the ChannelMismatch ERROR_IF check.

    Args:
        check: if True, compare dim 3 of input and result shapes, but only
            when the input rank is within the op's supported range.
        **kwargs: ``op`` (always required, with a ``'rank'`` (min, max)
            pair); ``input_shape`` and ``result_tensor`` when ``check``.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (rank 4).
    """
    assert 'op' in kwargs
    lowest, highest = kwargs['op']['rank']
    allowed_ranks = range(lowest, highest + 1)

    error_result = False
    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
        rank_ok = len(in_shape) in allowed_ranks
        if rank_ok and in_shape[3] != out_shape[3]:
            error_result = True

    return {
        "error_name": ErrorIf.ChannelMismatch,
        "error_result": error_result,
        "error_reason": "Input channel size not equal to output channel size",
        "param_reqs": {"rank": [4, 4], "dtype": None, "shape": None},
    }
1647
@staticmethod
def evStrideSmallerEqualZero(check=False, **kwargs):
    """Build (and optionally evaluate) the StrideSmallerEqualZero check.

    Args:
        check: if True, flag any stride component <= 0.
        **kwargs: when ``check`` is True: ``input_dtype``, ``output_dtype``
            and whichever of ``stride`` / ``stride_fp`` the dtype pair
            selects (see the table below).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']

        # (input is FLOAT, output is FLOAT) -> stride kwarg to inspect.
        # Mixed pairs arise in the wrong input/output type tests.
        stride_key_for = {
            (True, True): 'stride_fp',
            (True, False): 'stride_fp',
            (False, True): 'stride',
            (False, False): 'stride',
        }
        key = stride_key_for[(input_dtype == DType.FLOAT, output_dtype == DType.FLOAT)]
        stride = kwargs[key]

        if min(stride) <= 0:
            error_result = True

    return {
        "error_name": ErrorIf.StrideSmallerEqualZero,
        "error_result": error_result,
        "error_reason": "Stride value smaller than or equal zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1677
@staticmethod
def evStrideLargerEqualMax(check=False, **kwargs):
    """Build (and optionally evaluate) the StrideLargerEqualMax check.

    For INT8/INT16 inputs, a stride component is too large when it is
    >= ``16 << shift`` (``16 >> -shift`` for a negative shift).

    Args:
        check: if True, evaluate the condition.
        **kwargs: when ``check`` is True: ``shift``, ``input_dtype`` and
            ``stride`` (first two components are checked).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (INT8/INT16 dtypes).
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride']
        if input_dtype in [DType.INT8, DType.INT16]:
            bound = (16 << shift) if shift >= 0 else (16 >> -shift)
            if stride[0] >= bound or stride[1] >= bound:
                error_result = True

    return {
        "error_name": ErrorIf.StrideLargerEqualMax,
        "error_result": error_result,
        "error_reason": "Stride value larger than or equal to maximum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1702
1703
@staticmethod
def evStrideLargerDimension(check=False, **kwargs):
    """Build (and optionally evaluate) the StrideLargerDimension check.

    Only FLOAT inputs are checked. Each floating-point stride component
    must not exceed the matching input dimension: ``stride[0]`` is
    compared with ``input_shape[1]`` and ``stride[1]`` with
    ``input_shape[2]`` (the H/W dims named in the reason string).

    Args:
        check: if True, evaluate the condition.
        **kwargs: when ``check`` is True: ``input_shape``, ``input_dtype``
            and ``stride_fp``.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (FLOAT dtype).
    """
    error_name = ErrorIf.StrideLargerDimension
    param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    error_result = False
    error_reason = "Stride value larger than or equal to H/W dimension"

    if check:
        shape = kwargs['input_shape']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride_fp']

        # Both dimension comparisons must be grouped under the FLOAT test;
        # `is_float and y_cond or x_cond` would also flag non-FLOAT inputs
        # whenever the second stride component exceeded the width.
        if input_dtype == DType.FLOAT and (stride[0] > shape[1] or stride[1] > shape[2]):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1726
1727
@staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
    """Build (and optionally evaluate) the OffsetSmallerEqualMin check.

    An offset component is too small when it is <= ``-16 << shift``
    (``-16 >> -shift`` for a negative shift).

    Args:
        check: if True, evaluate the condition.
        **kwargs: when ``check`` is True: ``shift``, ``output_dtype`` and
            ``offset_fp`` (FLOAT output) or ``offset`` (otherwise).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (INT8/INT16 dtypes).
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        offset = kwargs['offset_fp'] if output_dtype == DType.FLOAT else kwargs['offset']

        lower = (-16 << shift) if shift >= 0 else (-16 >> -shift)
        if offset[0] <= lower or offset[1] <= lower:
            error_result = True

    return {
        "error_name": ErrorIf.OffsetSmallerEqualMin,
        "error_result": error_result,
        "error_reason": "Offset value smaller than or equal to minimum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1755
@staticmethod
def evOffsetLargerEqualMax(check=False, **kwargs):
    """Build (and optionally evaluate) the OffsetLargerEqualMax check.

    An offset component is too large when it is >= ``16 << shift``
    (``16 >> -shift`` for a negative shift).

    Args:
        check: if True, evaluate the condition.
        **kwargs: when ``check`` is True: ``shift``, ``output_dtype`` and
            ``offset_fp`` (FLOAT output) or ``offset`` (otherwise).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (INT8/INT16 dtypes).
    """
    error_name = ErrorIf.OffsetLargerEqualMax
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Offset value larger than or equal to maximum value"

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        if output_dtype == DType.FLOAT:
            offset = kwargs['offset_fp']
        else:
            offset = kwargs['offset']

        # Single bound computation covers both signs of shift.
        if shift >= 0:
            bound = 16 << shift
        else:
            bound = 16 >> -shift
        if offset[0] >= bound or offset[1] >= bound:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1787
@staticmethod
def evShiftNotZero(check=False, **kwargs):
    """Build (and optionally evaluate) the ShiftNotZero ERROR_IF check.

    Args:
        check: if True, flag a non-zero ``shift`` when both input and
            output dtypes are FLOAT.
        **kwargs: ``shift``, ``input_dtype`` and ``output_dtype`` when
            ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (FLOAT dtype).
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        float_to_float = input_dtype == DType.FLOAT and output_dtype == DType.FLOAT
        if float_to_float and shift != 0:
            error_result = True

    return {
        "error_name": ErrorIf.ShiftNotZero,
        "error_result": error_result,
        "error_reason": "Shift value must be zero for float input",
        "param_reqs": {"rank": None, "dtype": [DType.FLOAT], "shape": None},
    }
1809
1810
@staticmethod
def evShiftSmallerOne(check=False, **kwargs):
    """Build (and optionally evaluate) the ShiftSmallerOne ERROR_IF check.

    Args:
        check: if True, flag ``shift < 1`` when neither the input nor the
            output dtype is FLOAT.
        **kwargs: ``shift``, ``input_dtype`` and ``output_dtype`` when
            ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (INT8/INT16 dtypes).
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        if shift < 1:
            if input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
                error_result = True

    return {
        "error_name": ErrorIf.ShiftSmallerOne,
        "error_result": error_result,
        "error_reason": "Shift value smaller than one",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1832
@staticmethod
def evShiftLargerEleven(check=False, **kwargs):
    """Build (and optionally evaluate) the ShiftLargerEleven check.

    Args:
        check: if True, flag ``shift > 11``.
        **kwargs: ``shift`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (INT8/INT16 dtypes).
    """
    error_result = False
    if check and kwargs['shift'] > 11:
        error_result = True

    return {
        "error_name": ErrorIf.ShiftLargerEleven,
        "error_result": error_result,
        "error_reason": "Shift value larger than eleven",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1852
1853
@staticmethod
def evRankMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the RankMismatch ERROR_IF check.

    Args:
        check: if True, flag either input whose rank differs from the
            result tensor's rank.
        **kwargs: ``input1``, ``input2`` and ``result_tensor`` (objects
            with a ``.shape``) when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        first_rank = len(kwargs['input1'].shape)
        second_rank = len(kwargs['input2'].shape)
        result_rank = len(kwargs['result_tensor'].shape)
        if result_rank not in (first_rank,) or second_rank != result_rank:
            error_result = True

    return {
        "error_name": ErrorIf.RankMismatch,
        "error_result": error_result,
        "error_reason": "Input Rank does not match output rank",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1875
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
    """Build (and optionally evaluate) the InputZeroPointNotZero check.

    A non-zero input zero point is only legal for INT8/UINT8 inputs
    (INT8 only for MATMUL, where both inputs are checked).

    Args:
        check: if True, evaluate the condition.
        **kwargs: ``op`` (always required); when ``check`` is True also
            ``input_dtype`` and ``qinfo``, plus ``input2_dtype`` for
            MATMUL. ``qinfo`` is either a tuple whose element 0 is the
            input zero point, or an object whose ``.ints[k][1]`` holds the
            zero points (ints[0][1] = input zp, ints[1][1] = output zp or,
            for MATMUL, the second input's zp).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``; ``param_reqs['dtype']`` lists the op's type entries
        whose input dtype is not INT8/UINT8.
    """
    op = kwargs['op']

    # Input dtypes for which a non-zero input zero point is allowed.
    quantized = (DType.INT8, DType.UINT8)

    # Entries of op['types'] are plain dtypes or lists whose element 0 is
    # the input dtype (e.g. FULLY_CONNECTED). Keep only entries whose input
    # dtype is not quantized. Testing isinstance() on the whole types list
    # is always true and would drop the first two entries of every op.
    inputDtypes = [
        entry for entry in op['types']
        if (entry[0] if isinstance(entry, list) else entry) not in quantized
    ]

    error_name = ErrorIf.InputZeroPointNotZero
    param_reqs = {
        "rank": None,
        "dtype": inputDtypes,
        "shape": None
    }
    error_result = False
    error_reason = "Input DType not INT8 and zero point not 0"

    if check:
        input_dtype = kwargs['input_dtype']
        if isinstance(kwargs['qinfo'], tuple):
            qinfo = kwargs['qinfo']
            input_zero_point = qinfo[0]
        else:
            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            qinfo = kwargs['qinfo'].ints
            input_zero_point = qinfo[0][1]

        if op['op'] == Op.MATMUL:
            # Both MATMUL inputs carry a zero point; each must be 0 unless
            # that input is INT8.
            input1_dtype = kwargs['input_dtype']
            input2_dtype = kwargs['input2_dtype']
            qinfo = kwargs['qinfo'].ints
            input1_zero_point = qinfo[0][1]
            input2_zero_point = qinfo[1][1]
            if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
                error_result = True
        else:
            if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1927
1928
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
    """Build (and optionally evaluate) the WeightZeroPointNotZero check.

    Args:
        check: if True, flag a non-zero weight zero point when the weight
            dtype is not INT8.
        **kwargs: ``op`` (always required); ``weight_dtype`` and ``qinfo``
            when ``check`` is True (qinfo.ints[0][1] = input_zp,
            qinfo.ints[1][1] = weight_zp).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``; ``param_reqs['dtype']`` keeps the op's type entries
        except list entries whose element 1 (weight dtype) is INT8.
    """
    op = kwargs['op']

    candidate_types = []
    for entry in op['types']:
        int8_weights = isinstance(entry, list) and entry[1] == DType.INT8
        if not int8_weights:
            candidate_types.append(entry)

    error_result = False
    if check:
        weight_dtype = kwargs['weight_dtype']
        weight_zp = kwargs['qinfo'].ints[1][1]
        if weight_dtype != DType.INT8 and weight_zp != 0:
            error_result = True

    return {
        "error_name": ErrorIf.WeightZeroPointNotZero,
        "error_result": error_result,
        "error_reason": "Weight DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": candidate_types, "shape": None},
    }
1961
1962
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
    """Build (and optionally evaluate) the OutputZeroPointNotZero check.

    Args:
        check: if True, flag a non-zero output zero point when the output
            dtype is not INT8/UINT8 (for AVG_POOL2D: when the input dtype is
            not INT8).
        **kwargs: ``op`` (always required); ``input_dtype``,
            ``output_dtype`` and ``qinfo`` when ``check`` is True. ``qinfo``
            is a tuple whose element 1 is the output zp, or an object with
            qinfo.ints[1][1] = output_zp.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``; ``param_reqs['dtype']`` is the op's type list with
        (the first occurrence of) INT8 and UINT8 removed.
    """
    op = kwargs['op']
    candidate_types = list(op['types'])
    for quantized in (DType.INT8, DType.UINT8):
        if quantized in candidate_types:
            candidate_types.remove(quantized)

    error_result = False
    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        qinfo = kwargs['qinfo']
        if isinstance(qinfo, tuple):
            output_zp = qinfo[1]
        else:
            output_zp = qinfo.ints[1][1]

        if op['op'] == Op.AVG_POOL2D:
            # AVG_POOL2D keys the check on the input dtype
            zp_permitted = input_dtype == DType.INT8
        else:
            zp_permitted = output_dtype in [DType.INT8, DType.UINT8]
        if not zp_permitted and output_zp != 0:
            error_result = True

    return {
        "error_name": ErrorIf.OutputZeroPointNotZero,
        "error_result": error_result,
        "error_reason": "Output DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": candidate_types, "shape": None},
    }
2004
@staticmethod
def evAxisSmallerZero(check=False, **kwargs):
    """Build (and optionally evaluate) the AxisSmallerZero ERROR_IF check.

    Args:
        check: if True, flag a negative ``axis``.
        **kwargs: ``axis`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    negative_axis = False
    if check and kwargs['axis'] < 0:
        negative_axis = True

    return {
        "error_name": ErrorIf.AxisSmallerZero,
        "error_result": negative_axis,
        "error_reason": "Axis smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2024
2025
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
    """Build (and optionally evaluate) the AxisLargerRank ERROR_IF check.

    Valid axes are ``0 .. rank - 1``, so an axis equal to the rank is also
    out of range. This matches the ``0 <= axis < len(shape)`` validity
    tests used by the other axis checks.

    Args:
        check: if True, flag ``axis >= len(input_shape)``.
        **kwargs: ``axis`` and ``input_shape`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_name = ErrorIf.AxisLargerRank
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Axis larger than rank"

    if check:
        axis = kwargs['axis']
        shape = kwargs['input_shape']
        if axis >= len(shape):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2046
2047
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
    """Build (and optionally evaluate) the ShapeOfAxisNotOne check.

    Args:
        check: if True, flag ``output_shape[axis] != 1`` for an in-range
            axis (out-of-range axes are left to other checks).
        **kwargs: ``axis`` and ``output_shape`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        axis = kwargs['axis']
        out_shape = kwargs['output_shape']
        axis_in_range = 0 <= axis < len(out_shape)
        if axis_in_range and out_shape[axis] != 1:
            error_result = True

    return {
        "error_name": ErrorIf.ShapeOfAxisNotOne,
        "error_result": error_result,
        "error_reason": "shape[axis] is not equal to 1",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2068
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002069
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
    """Build (and optionally evaluate) the PadSmallerZero ERROR_IF check.

    Args:
        check: if True, flag any negative pad value.
        **kwargs: ``op`` and ``pad`` when ``check`` is True. For PAD,
            ``pad`` is a sequence of per-dimension padding sequences; for
            other ops it is a flat sequence.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        op = kwargs['op']
        pad = kwargs['pad']
        if op['op'] == Op.PAD:
            # Every per-dimension entry is inspected.
            negative_entries = [min(padding) < 0 for padding in pad]
            if any(negative_entries):
                error_result = True
        elif min(pad) < 0:
            error_result = True

    return {
        "error_name": ErrorIf.PadSmallerZero,
        "error_result": error_result,
        "error_reason": "At least one pad is smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2095
2096
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
    """Build (and optionally evaluate) the PadLargerEqualKernel check.

    Pads [0] and [1] are compared with kernel[0]; pads [2] and [3] with
    kernel[1].

    Args:
        check: if True, evaluate the condition.
        **kwargs: ``pad`` and ``kernel`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        pad = kwargs['pad']
        kernel = kwargs['kernel']
        # Only evaluated when every pad is positive and every kernel dim is
        # greater than 1. NOTE(review): with this guard e.g. pad=[0, 3, 0, 0]
        # and kernel=[2, 2] is not reported - confirm that is intended.
        if min(pad) > 0 and min(kernel) > 1:
            kernel_y = kernel[0]
            if pad[0] >= kernel_y or pad[1] >= kernel_y:
                error_result = True
            else:
                kernel_x = kernel[1]
                if pad[2] >= kernel_x or pad[3] >= kernel_x:
                    error_result = True

    return {
        "error_name": ErrorIf.PadLargerEqualKernel,
        "error_result": error_result,
        "error_reason": "At least one pad is larger than kernel dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2118
@staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the PoolingOutputShapeMismatch check.

    The expected size of each spatial dim is
    ``(in + pad_before + pad_after + stride - kernel) // stride``. A
    mismatch is only reported when kernel >= 1, stride >= 1, pad >= 0 and
    every pad is smaller than its kernel dim.

    Args:
        check: if True, evaluate the condition.
        **kwargs: when ``check`` is True: ``pad`` (top, bottom, left,
            right), ``kernel`` (y, x), ``stride`` (y, x), and
            ``input_shape`` / ``output_shape`` (dims 1 and 2 are H and W).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    mismatch = False

    if check:
        pad = kwargs['pad']
        top, bottom, left, right = pad[0], pad[1], pad[2], pad[3]

        kernel = kwargs['kernel']
        k_y, k_x = kernel[0], kernel[1]

        in_shape = kwargs['input_shape']
        in_h, in_w = in_shape[1], in_shape[2]

        out_shape = kwargs['output_shape']
        out_h, out_w = out_shape[1], out_shape[2]

        stride = kwargs['stride']
        s_y, s_x = stride[0], stride[1]

        # Illegal kernel/stride/pad values are reported by other checks.
        legal = (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and top < k_y and bottom < k_y and left < k_x and right < k_x
        )
        if legal:
            expected_h = (in_h + top + bottom + s_y - k_y) // s_y
            expected_w = (in_w + left + right + s_x - k_x) // s_x
            if out_h != expected_h or out_w != expected_w:
                mismatch = True

    return {
        "error_name": ErrorIf.PoolingOutputShapeMismatch,
        "error_result": mismatch,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2161
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the ArgmaxOutputShapeMismatch check.

    The expected output shape is the input shape with dimension ``axis``
    removed. Dimensions are only compared when the axis is in range and
    the output rank is already input rank - 1. Invalid axes and rank
    mismatches are reported by other checks; without the axis guard an
    out-of-range axis would index past the end of ``output_shape``.

    Args:
        check: if True, evaluate the condition.
        **kwargs: ``output_shape``, ``input_shape`` and ``axis`` when
            ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (ranks 2..4).
    """
    error_name = ErrorIf.ArgmaxOutputShapeMismatch
    param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Mismatch between output shape provided and expected output shape"

    if check:
        output_shape = kwargs['output_shape']
        input_shape = kwargs['input_shape']
        axis = kwargs['axis']

        axis_valid = 0 <= axis < len(input_shape)
        if axis_valid and (len(input_shape) - 1) == len(output_shape):
            expected = [dim for i, dim in enumerate(input_shape) if i != axis]
            if any(want != got for want, got in zip(expected, output_shape)):
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2196
@staticmethod
def evArgmaxOutputRankMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the ArgmaxOutputRankMismatch check.

    Args:
        check: if True, flag an output rank other than input rank - 1,
            but only for an in-range ``axis``.
        **kwargs: ``output_shape``, ``input_shape`` and ``axis`` when
            ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        out_shape = kwargs['output_shape']
        in_shape = kwargs['input_shape']
        axis = kwargs['axis']
        axis_ok = 0 <= axis < len(in_shape)
        if axis_ok and len(out_shape) != len(in_shape) - 1:
            error_result = True

    return {
        "error_name": ErrorIf.ArgmaxOutputRankMismatch,
        "error_result": error_result,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2220
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002221
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
    """Build (and optionally evaluate) the KernelSmallerOne ERROR_IF check.

    Args:
        check: if True, flag any kernel dimension below 1.
        **kwargs: ``kernel`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_name = ErrorIf.KernelSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The check is `< 1`, so the reason names one (not zero) as the bound.
    error_reason = "At least one kernel dimension is smaller than one"

    if check:
        kernel = kwargs['kernel']
        if min(kernel) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2241
@staticmethod
def evStrideSmallerOne(check=False, **kwargs):
    """Build (and optionally evaluate) the StrideSmallerOne ERROR_IF check.

    Args:
        check: if True, flag any stride dimension below 1.
        **kwargs: ``stride`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_name = ErrorIf.StrideSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The check is `< 1`, so the reason names one (not zero) as the bound.
    error_reason = "At least one stride dimension is smaller than one"

    if check:
        stride = kwargs['stride']
        if min(stride) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2261
@staticmethod
def evScaleTrue(check=False, **kwargs):
    """Build (and optionally evaluate) the ScaleTrue ERROR_IF check.

    Args:
        check: if True, flag ``scale32`` set together with an INT48 input.
        **kwargs: ``input_dtype`` and ``scale32`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (INT48 dtype).
    """
    error_result = False

    if check:
        in_dtype = kwargs['input_dtype']
        use_scale32 = kwargs['scale32']
        if use_scale32 and in_dtype == DType.INT48:
            error_result = True

    return {
        "error_name": ErrorIf.ScaleTrue,
        "error_result": error_result,
        "error_reason": "Scale set to true but input type is INT48",
        "param_reqs": {"rank": None, "dtype": [DType.INT48], "shape": None},
    }
2282
@staticmethod
def evScaleNotTrue(check=False, **kwargs):
    """Build (and optionally evaluate) the ScaleNotTrue ERROR_IF check.

    Args:
        check: if True, flag ``double_round`` set while ``scale32`` is not.
        **kwargs: ``scale32`` and ``double_round`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        use_scale32 = kwargs['scale32']
        use_double_round = kwargs['double_round']
        if not use_scale32:
            if use_double_round:
                error_result = True

    return {
        "error_name": ErrorIf.ScaleNotTrue,
        "error_result": error_result,
        "error_reason": "Scale set to false but double round set to true",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2303
@staticmethod
def evTensorSizeInputOutputMismatch(check=False, **kwargs):
    """Build (and optionally evaluate) the TensorSizeInputOutputMismatch check.

    Args:
        check: if True, flag differing element counts (``np.prod`` of each
            shape).
        **kwargs: ``input_shape`` and ``output_shape`` when ``check``.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['output_shape']
        in_elements = np.prod(in_shape)
        out_elements = np.prod(out_shape)
        if in_elements != out_elements:
            error_result = True

    return {
        "error_name": ErrorIf.TensorSizeInputOutputMismatch,
        "error_result": error_result,
        "error_reason": "Input tensor size does not match output tensor size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2326
@staticmethod
def evStartSmallerZero(check=False, **kwargs):
    """Build (and optionally evaluate) the StartSmallerZero ERROR_IF check.

    Args:
        check: if True, flag a negative ``start`` component. The check is
            skipped when ``start`` has a different length than the input
            rank.
        **kwargs: ``input_shape`` and ``start`` when ``check`` is True.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs`` (unconstrained).
    """
    error_result = False

    if check:
        in_shape = kwargs['input_shape']
        start = kwargs['start']
        if len(start) == len(in_shape):
            negative_starts = [s for s in start if s < 0]
            if negative_starts:
                error_result = True

    return {
        "error_name": ErrorIf.StartSmallerZero,
        "error_result": error_result,
        "error_reason": "Starting point smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2350
2351
@staticmethod
def evSizeSmallerEqualZero(check=False, **kwargs):
    """ERROR_IF validator: a slice size entry is zero or negative.

    The check runs only when ``size`` has one entry per input dimension.

    Args:
        check: When True, evaluate the condition from ``kwargs``.
        **kwargs: Must contain ``input_shape`` and ``size`` when ``check`` is
            True.

    Returns:
        dict with keys ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.SizeSmallerEqualZero
    non_positive = False

    if check:
        input_shape = kwargs['input_shape']
        size = kwargs['size']
        rank = len(input_shape)
        if len(size) == rank:
            non_positive = any([size[dim] <= 0 for dim in range(rank)])

    return {
        "error_name": error_name,
        "error_result": non_positive,
        "error_reason": "Size smaller than or equal to zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2375
2376
@staticmethod
def evStartSizeOutsideBounds(check=False, **kwargs):
    """ERROR_IF validator: start + size exceeds the input extent in some dimension.

    The check runs only when both ``start`` and ``size`` have one entry per
    input dimension.

    Args:
        check: When True, evaluate the condition from ``kwargs``.
        **kwargs: Must contain ``input_shape``, ``start`` and ``size`` when
            ``check`` is True.

    Returns:
        dict with keys ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.StartSizeOutsideBounds
    overflow = False

    if check:
        input_shape = kwargs['input_shape']
        start = kwargs['start']
        size = kwargs['size']
        rank = len(input_shape)
        if len(start) == rank and len(size) == rank:
            overflow = any(
                [start[dim] + size[dim] > input_shape[dim] for dim in range(rank)]
            )

    return {
        "error_name": error_name,
        "error_result": overflow,
        "error_reason": "starting point plus size larger than input dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2401
2402
@staticmethod
def evSizeOutputShapeMismatch(check=False, **kwargs):
    """ERROR_IF validator: slice ``size`` differs from the output shape.

    When ``size`` has one entry per input dimension, each entry is compared
    with the matching ``output_shape`` entry.

    Args:
        check: When True, evaluate the condition from ``kwargs``.
        **kwargs: Must contain ``input_shape``, ``output_shape`` and ``size``
            when ``check`` is True.

    Returns:
        dict with keys ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.SizeOutputShapeMismatch
    differs = False

    if check:
        input_shape = kwargs['input_shape']
        output_shape = kwargs['output_shape']
        size = kwargs['size']
        rank = len(input_shape)
        if len(size) == rank:
            # Build the full list (no short-circuit) so every dimension is
            # indexed, exactly as a plain loop would.
            per_dim = [size[dim] != output_shape[dim] for dim in range(rank)]
            differs = any(per_dim)

    return {
        "error_name": error_name,
        "error_result": differs,
        "error_reason": "Size does not match output dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2427
@staticmethod
def evInputSizeStartLengthMismatch(check=False, **kwargs):
    """ERROR_IF validator: ``start`` or ``size`` length differs from the input rank.

    Args:
        check: When True, evaluate the condition from ``kwargs``.
        **kwargs: Must contain ``input_shape``, ``start`` and ``size`` when
            ``check`` is True.

    Returns:
        dict with keys ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.InputSizeStartLengthMismatch
    length_mismatch = False

    if check:
        input_shape = kwargs['input_shape']
        start = kwargs['start']
        size = kwargs['size']
        rank = len(input_shape)
        length_mismatch = len(start) != rank or len(size) != rank

    return {
        "error_name": error_name,
        "error_result": length_mismatch,
        "error_reason": "rank of input not equal to length of start or size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2450
@staticmethod
def evIndexOutsideBounds(check=False, **kwargs):
    """ERROR_IF validator: a permutation index lies outside [0, rank).

    Args:
        check: When True, evaluate the condition from ``kwargs``.
        **kwargs: Must contain ``input_shape`` and ``perms`` when ``check`` is
            True.

    Returns:
        dict with keys ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.IndexOutsideBounds
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index outside of allowed bounds"

    if check:
        input_shape = kwargs['input_shape']
        perms = kwargs['perms']
        rank = len(input_shape)

        # Valid permutation entries are 0 .. rank-1, so an index equal to
        # rank is already out of bounds (previously missed by "> rank").
        error_result = any(index < 0 or index >= rank for index in perms)

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2474
@staticmethod
def evIndexUsedTwice(check=False, **kwargs):
    """ERROR_IF validator: the same permutation index appears more than once.

    Args:
        check: When True, evaluate the condition from ``kwargs``.
        **kwargs: Must contain ``perms`` when ``check`` is True.

    Returns:
        dict with keys ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``. ``param_reqs`` restricts this error to ranks 2..4.
    """
    error_name = ErrorIf.IndexUsedTwice
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index used multiple times"

    if check:
        perms = kwargs['perms']
        # Track seen indices in a set: O(n) instead of O(n^2) list scans.
        seen = set()
        for index in perms:
            if index in seen:
                error_result = True
                break
            seen.add(index)

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002501
2502
class TosaInvalidValidator:
    """Predicates that flag generated test configurations as invalid.

    Every ``iv*`` method takes keyword arguments describing one candidate test
    and returns True when that combination is invalid, False otherwise.
    """

    @staticmethod
    def _convOutputSize(in_size, k_size, dilation, pad_before, pad_after, stride):
        """Output length of a dilated, padded, strided convolution along one axis."""
        return (
            in_size
            - k_size
            - (k_size - 1) * (dilation - 1)
            + pad_before
            + pad_after
        ) // stride + 1

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode / dtype combination is not supported.

        kwargs:
            input_dtype: dtype of the input tensor.
            args: RESIZE argument list; ``args[0]`` is the resize mode and
                ``args[8]`` the output dtype.

        Supported combinations:
            BILINEAR: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT
            NEAREST:  INT8, INT16 or FLOAT, with output dtype == input dtype
        Any other resize mode is invalid.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # The valid pairs are mutually exclusive, so the combination is
            # invalid exactly when it matches none of them.
            valid_pairs = (
                (DType.INT8, DType.INT32),
                (DType.INT16, DType.INT48),
                (DType.FLOAT, DType.FLOAT),
            )
            return (input_dtype, output_dtype) not in valid_pairs
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return bool(
                input_dtype != output_dtype
                or input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if the stride is zero or negative on either axis.

        FLOAT inputs are checked against the floating-point stride ``args[4]``;
        every other dtype against the integer stride ``args[1]``.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            # Negative or zero stride
            return bool(stride_fp_x <= 0 or stride_fp_y <= 0)
        # Negative or zero stride
        return bool(stride_x <= 0 or stride_y <= 0)

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if a conv2d / depthwise_conv2d / pool2d config gives H or W <= 0.

        kwargs:
            opName: operator test name; selects the size formula.
            shapeList: ``[input_shape, filter_shape, ...]``. Input dims 1 and 2
                are H and W. The filter (convolutions only) has its kernel H/W
                at dims 1/2 for conv2d and 0/1 for depthwise_conv2d.
            args: ``[strides, padding, dilations]`` for convolutions,
                ``[strides, padding, kernel]`` for pooling; padding is
                (top, bottom, left, right).

        Raises:
            AssertionError: if ``opName`` is not one of the handled operators.
        """
        op_name = kwargs['opName']
        shapes = kwargs['shapeList']
        args = kwargs['args']
        ifm_shape = shapes[0]
        strides = args[0]
        padding = args[1]
        out_size = TosaInvalidValidator._convOutputSize

        if op_name.startswith('conv2d'):
            weights = shapes[1]
            dilations = args[2]
            h = out_size(ifm_shape[1], weights[1], dilations[0], padding[0], padding[1], strides[0])
            w = out_size(ifm_shape[2], weights[2], dilations[1], padding[2], padding[3], strides[1])
        elif op_name.startswith("depthwise_conv2d"):
            weights = shapes[1]
            dilations = args[2]
            h = out_size(ifm_shape[1], weights[0], dilations[0], padding[0], padding[1], strides[0])
            w = out_size(ifm_shape[2], weights[1], dilations[1], padding[2], padding[3], strides[1])
        elif op_name.endswith("pool2d"):
            kernel = args[2]
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            # Explicit raise so the check survives "python -O".
            raise AssertionError("Unrecognized Op")

        # Invalid parameter combination when either spatial output is empty
        return h <= 0 or w <= 0

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if ``args[3]`` (the requested output shape) has dim 1 or 2 <= 0.

        NOTE(review): args[3] matches the output_shape position used by
        transpose_conv2d builds; confirm no other op relies on this validator.
        """
        output_shape = kwargs['args'][3]
        # Negative output shape
        return bool(output_shape[1] <= 0 or output_shape[2] <= 0)
2619
2620
Kevin Cheng550ccc52021-03-03 11:21:43 -08002621
Eric Kunzee5e26762020-10-13 16:11:07 -07002622class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002623 # Maximum rank of tensor supported by test generator.
2624 TOSA_TENSOR_MAX_RANK = 6
2625
def __init__(self, args):
    """Initialise the test generator.

    Args:
        args: configuration namespace (presumably the parsed command-line
            arguments); ``output_dir`` and ``random_seed`` are read here.

    ``self.ser`` starts as None and is assigned by createSerializer().
    ``self.rng`` is seeded from ``args.random_seed`` so runs are reproducible.
    """
    out_dir = args.output_dir
    seed = args.random_seed

    self.args, self.basePath, self.random_seed = args, out_dir, seed
    self.ser = None
    self.rng = np.random.default_rng(seed)

    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()

    # When set, makeShape() returns this shape instead of a random one
    self.targetted_shape = None
2637
def createSerializer(self, opName, testPath):
    """Start a new serializer writing into ``<basePath>/<opName>/<testPath>``.

    The directory is created if missing. ``self.testPath`` keeps the path
    relative to ``basePath`` so serialize() can write next to it.
    """
    relative_dir = os.path.join(opName, testPath)
    self.testPath = relative_dir

    target_dir = os.path.join(self.basePath, relative_dir)
    os.makedirs(target_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(target_dir)
2644
def getSerializer(self):
    """Return the serializer most recently set up by createSerializer() (or None)."""
    serializer = self.ser
    return serializer
2647
def serialize(self, testName):
    """Write ``<testName>.tosa`` and its ``desc.json`` into the current test directory.

    The binary graph comes from ``self.ser.serialize()``; the JSON descriptor
    comes from ``self.ser.writeJson`` and references the ``.tosa`` file name.
    """
    out_dir = os.path.join(self.basePath, self.testPath)
    tosa_file = "{}.tosa".format(testName)

    with open(os.path.join(out_dir, tosa_file), "wb") as graph_fd:
        graph_fd.write(self.ser.serialize())

    with open(os.path.join(out_dir, "desc.json"), "w") as desc_fd:
        desc_fd.write(self.ser.writeJson(tosa_file))
Eric Kunzee5e26762020-10-13 16:11:07 -07002656
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded NumPy Generator.

    Args:
        seed: Seed for the new generator. When omitted (None), uses
            ``self.random_seed + 1``.
    """
    # "is None" (not "== None"): an array-like seed compared with == would
    # produce an element-wise array whose truth value is ambiguous.
    if seed is None:
        seed = self.random_seed + 1
    self.rng = np.random.default_rng(seed)
2661
def getRandTensor(self, shape, dtype):
    """Return a random NumPy array of ``shape`` whose values fit ``dtype``.

    BOOL -> np.bool_; INT48 -> np.int64; FLOAT -> np.float32 in [0, 1);
    INT4/INT8/UINT8/INT16/INT32 -> np.int32 drawn uniformly from the type's
    range (INT4 is limited to [-7, 7]).

    Raises:
        Exception: if ``dtype`` is not one of the handled types.
    """
    if dtype == DType.BOOL:
        # np.bool_ is the NumPy boolean type. The old "np.bool" alias (an unused
        # local here before) was removed in NumPy 1.24 and raised AttributeError.
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        return np.int32(self.rng.integers(low=-7, high=8, size=shape))
    elif dtype == DType.INT8:
        return np.int32(self.rng.integers(low=-128, high=128, size=shape))
    elif dtype == DType.UINT8:
        return np.int32(self.rng.integers(low=0, high=256, size=shape))
    elif dtype == DType.INT16:
        return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
    elif dtype == DType.INT32:
        return np.int32(
            self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
        )
    elif dtype == DType.INT48:
        # 48-bit values do not fit in int32
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002687
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per (shape, dtype) pair.

    Args:
        shape_list: shapes of the placeholders.
        dtype_list: dtypes of the placeholders; must match ``shape_list`` in length.

    Returns:
        List of the tensors returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addPlaceholder(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
2698
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant per (shape, dtype) pair.

    Args:
        shape_list: shapes of the constants.
        dtype_list: dtypes of the constants; must match ``shape_list`` in length.

    Returns:
        List of the tensors returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addConst(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
2709
def makeShape(self, rank):
    """Return an int32 shape array.

    If a target shape was set (via setTargetShape) and is truthy, it is
    returned and ``rank`` is ignored. Otherwise ``rank`` dimensions are drawn
    from ``self.args.tensor_shape_range`` (low inclusive, high exclusive).
    """
    if not self.targetted_shape:
        low = self.args.tensor_shape_range[0]
        high = self.args.tensor_shape_range[1]
        return np.int32(self.rng.integers(low=low, high=high, size=rank))
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002720
def setTargetShape(self, shape):
    """Make makeShape() return ``shape``; a falsy value (e.g. None) restores random shapes."""
    self.targetted_shape = shape
2723
def randInt(self, low=0, high=256):
    """Return a single random np.int32 drawn from [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
2726
def getRandNumberDType(self, dtype):
    """Return one random scalar valid for ``dtype``.

    FLOAT -> Python float in [0, 1); BOOL -> random bool; INT48 -> np.int64;
    INT4/INT8/UINT8/INT16/INT32 -> np.int32 drawn from the type's range
    (INT4 limited to [-7, 7]). UINT8 uses [0, 255], matching getRandTensor.

    Raises:
        Exception: if ``dtype`` is not one of the handled types.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: 48-bit values need int64
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
2749
def shapeStr(self, shape):
    """Render ``shape`` as its dimensions joined by "x", e.g. [1, 8, 8] -> "1x8x8"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002758
def typeStr(self, t):
    """Return the short dtype tag used in generated test names.

    A list of dtypes (at least two entries) is rendered as "<tag0>x<tag1>";
    entries after the second are ignored.

    Raises:
        Exception: if ``t`` is not a recognised DType.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))

    tags = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for dtype, tag in tags:
        if t == dtype:
            return tag
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002782
def typeWidth(self, t):
    """Return the bit width of datatype ``t``.

    Integer types give their storage width, FLOAT gives 32 and BOOL gives 1.

    Raises:
        Exception: if ``t`` is not a recognised DType.
    """
    widths = (
        (DType.INT4, 4),
        (DType.INT8, 8),
        (DType.UINT8, 8),
        (DType.INT16, 16),
        (DType.INT32, 32),
        (DType.INT48, 48),
        (DType.FLOAT, 32),
        (DType.BOOL, 1),
    )
    for dtype, width in widths:
        if t == dtype:
            return width
    # Message previously said "cannot convert to string" (copied from typeStr)
    raise Exception("Unknown dtype, cannot determine width: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002803
# Argument generators
# Return a list of tuples (stringDescriptor, [build_fcn_arg_list]),
# where the string descriptor is used to generate the test name and
# the build_fcn_arg_list is expanded and passed to the operator test
# build function.
2809
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Build a single-input operator and return its output tensor.

    Args:
        op: either a bare operator code (int, as passed by placeholder builds)
            or an op-definition dict with ``'op'`` and ``'operands'`` keys.
        a: input tensor.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
        qinfo: quantization info attached to the operator.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder passes an int; IDENTITY is emitted directly.
    # Both bypass the ERROR_IF machinery below.
    if isinstance(op, int) or op['op'] == Op.IDENTITY:
        opcode = op if isinstance(op, int) else op['op']
        self.ser.addOperator(opcode, a.name, out_tensor.name, None, qinfo)
        return out_tensor

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType and out_tensor.dtype not in (DType.INT8, DType.UINT8):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, a.dtype),
            TosaQuantGen.getQinfo(self, out_tensor.dtype),
        )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op['op'], inputs, outputs, None, qinfo)
    return out_tensor
2852
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Build a two-input operator whose output shape comes from the broadcast shaper.

    Args:
        op: op-definition dict with ``'op'`` and ``'operands'`` keys.
        a, b: input tensors.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op['op'], inputs, outputs)
    return out_tensor
2881
def build_binary_nonbroadcast(self, op, a, b):
    """Build a two-input operator using the non-broadcasting output shaper."""
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operands = [a.name, b.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name])
    return out_tensor
2886
def build_arithmetic_right_shift(self, op, a, b, round):
    """Build ARITHMETIC_RIGHT_SHIFT; ``round`` is stored in the operator attribute."""
    out_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    operands = [a.name, b.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name], shift_attr)
    return out_tensor
2895
def build_mul(self, op, a, b, shift):
    """Build MUL with the given ``shift`` attribute.

    Non-FLOAT inputs always produce an INT32 result tensor.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

    # Special for multiply: force the result to INT32 for INT types
    if a.dtype != DType.FLOAT:
        out_tensor.setDtype(DType.INT32)

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    operands = [a.name, b.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name], mul_attr)
    return out_tensor
2909
def build_table(self, op, a):
    """Build TABLE with a random constant lookup table sized for ``a``'s dtype.

    INT16 inputs get a 513-entry INT16 table; INT8 inputs (asserted otherwise)
    get a 256-entry INT8 table.
    """
    if a.dtype == DType.INT16:
        table_dtype, table_len = DType.INT16, 513
    else:
        assert a.dtype == DType.INT8
        table_dtype, table_len = DType.INT8, 256
    table_arr = self.getRandTensor([table_len], table_dtype)

    table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
    out_tensor = OutputShaper.tableOp(self.ser, a, table_dtype)
    self.ser.addOperator(op['op'], [a.name, table_tens.name], [out_tensor.name], None)

    return out_tensor
2925
def build_select(self, op, cond, a, b):
    """Build SELECT choosing between ``a`` and ``b`` using the ``cond`` tensor."""
    out_tensor = OutputShaper.selectOp(self.ser, cond, a, b)
    operands = [cond.name, a.name, b.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name])
    return out_tensor
2930
def build_comparison(self, op, a, b):
    """Build a two-input comparison operator using the comparison output shaper."""
    out_tensor = OutputShaper.binaryComparisonOp(self.ser, a, b)
    operands = [a.name, b.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name])
    return out_tensor
2935
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Build ARGMAX along ``axis`` with ERROR_IF validation.

    Args:
        op: op-definition dict with ``'op'`` and ``'operands'`` keys.
        a: input tensor.
        axis: reduction axis, stored in an AxisAttribute.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
    """
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], inputs, outputs, axis_attr)
    return out_tensor
2967
def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
    """Build a 2D pooling operator (PoolAttribute: kernel, stride, pad).

    Args:
        op: op-definition dict with ``'op'`` and ``'operands'`` keys.
        input: input tensor.
        stride, pad, kernel: pooling parameters.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
        qinfo: quantization info attached to the operator.
    """
    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # For WrongInputType tests with a non-8-bit input, rebuild qinfo so it
    # matches the substituted input type.
    if error_name == ErrorIf.WrongInputType and input.dtype not in (DType.INT8, DType.UINT8):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, input.dtype),
            TosaQuantGen.getQinfo(self, out_tensor.dtype),
        )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad)

    self.ser.addOperator(op['op'], inputs, outputs, pool_attr, qinfo)
    return out_tensor
3010
def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Build CONV2D; ``padding`` must have exactly 4 entries."""
    assert len(padding) == 4
    out_tensor = OutputShaper.conv2dOp(self.ser, ifm, filter, strides, padding, dilations)

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operands = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name], conv_attr, qinfo)
    return out_tensor
3024
def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Build CONV3D; ``padding`` must have exactly 6 entries."""
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(self.ser, ifm, filter, strides, padding, dilations)

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operands = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name], conv_attr, qinfo)
    return out_tensor
3038
def build_transpose_conv2d(
    self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
):
    """Build TRANSPOSE_CONV2D producing ``output_shape``; ``outpad`` must have 2 entries."""
    assert len(outpad) == 2
    out_tensor = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    operands = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name], tconv_attr, qinfo)
    return out_tensor
3052
def build_depthwise_conv2d(
    self, op, ifm, filter, bias, strides, padding, dilations, qinfo
):
    """Build DEPTHWISE_CONV2D with a ConvAttribute(padding, strides, dilations)."""
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operands = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operands, [out_tensor.name], conv_attr, qinfo)
    return out_tensor
3067
def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
    """Build FULLY_CONNECTED (inputs: ifm, filter, bias) with ERROR_IF validation.

    Args:
        op: op-definition dict with ``'op'`` and ``'operands'`` keys.
        ifm, filter, bias: input tensors.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
        qinfo: quantization info attached to the operator.
    """
    out_tensor = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op['op'], inputs, outputs, None, qinfo)
    return out_tensor
3099
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Build MATMUL of ``a`` and ``b`` with ERROR_IF validation.

    Args:
        op: op-definition dict with ``'op'`` and ``'operands'`` keys.
        a, b: input tensors.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
        qinfo: quantization info attached to the operator.
    """
    out_tensor = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op['op'], inputs, outputs, None, qinfo)
    return out_tensor
3130
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Build a REDUCE_* operator along ``axis`` with ERROR_IF validation.

    Args:
        op: op-definition dict with ``'op'`` and ``'operands'`` keys.
        a: input tensor.
        axis: reduction axis, stored in an AxisAttribute.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf value being generated, or None for a valid test.
    """
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], inputs, outputs, axis_attr)
    return out_tensor
3162
def build_clamp(self, op, a):
    """Serialize a CLAMP operator on ``a`` with randomly chosen bounds.

    Two random values of ``a``'s dtype are drawn; the smaller is the lower
    bound and the larger the upper bound. FLOAT inputs set the float bounds
    (integer bounds left at 0); every other dtype sets the integer bounds
    (float bounds left at 0).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    attr = ts.TosaSerializerAttribute()
    bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
    lower, upper = min(bounds), max(bounds)

    if a.dtype != DType.FLOAT:
        attr.ClampAttribute(lower, upper, 0, 0)
    else:
        attr.ClampAttribute(0, 0, lower, upper)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
3176
def build_leaky_relu(self, op, a):
    """Serialize a LEAKY_RELU operator on ``a`` with a random FLOAT alpha."""
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FLOAT)
    attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
3185
# Needs an additional type/input
def build_prelu(self, op, a):
    """Serialize a PRELU operator with ``a`` as its only input.

    NOTE(review): only one input is wired up here, although the comment
    above says an additional input is still required.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
3192
def build_sigmoid(self, op, a):
    """Serialize a SIGMOID operator on ``a`` and return its output tensor."""
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
3197
def build_tanh(self, op, a):
    """Serialize a TANH operator on ``a`` and return its output tensor."""
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
3202
def build_concat(self, op, *a):
    """Serialize a CONCAT operator.

    ``a`` holds the input tensors followed by the integer concatenation
    axis as its last element. The axis travels with the variable-length
    tensor list so both can be passed through one variadic parameter.
    """
    assert type(a[-1]) is int

    *tensors, axis = a

    result_tens = OutputShaper.concatOp(self.ser, axis, *tensors)

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    input_names = [tensor.name for tensor in tensors]
    self.ser.addOperator(op["op"], input_names, [result_tens.name], attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003221
def build_pad(self, op, a, padding, validator_fcns=None, error_name=None, qinfo=None):
    """Serialize a PAD operator on ``a`` using the given ``padding`` array.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The padding array is one of the few tensor operands that is not
    # randomly generated, so it is serialized here as an INT32 constant.
    pad_const = self.ser.addConst(padding.shape, DType.INT32, padding)

    param_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, pad_const.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=param_count + const_count,
    )

    self.ser.addOperator(op["op"], inputs, outputs, None, qinfo)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003258
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of ``a`` to ``newShape``.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    param_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=param_count + const_count,
    )

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
3289
def build_reverse(self, op, a, axis):
    """Serialize a REVERSE of ``a`` along ``axis``; returns the output tensor."""
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], axis_attr)
    return out_tensor
3298
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of ``a`` by permutation ``perms``.

    ``perms`` is emitted as a 1-D INT32 constant input tensor.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_const = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

    param_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, perm_const.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=param_count + const_count,
    )

    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
3330
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of ``a`` beginning at ``start`` with extent ``size``.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    param_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        start=start,
        size=size,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=param_count + const_count,
    )

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return result_tens
3363
def build_tile(self, op, a, multiples):
    """Serialize a TILE of ``a`` by ``multiples``; returns the output tensor."""
    out_tensor = OutputShaper.tileOp(self.ser, a, multiples)

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], tile_attr)
    return out_tensor
3372
def build_gather(self, op, values):
    """Serialize a GATHER operator reading from ``values``.

    The indices tensor is created here, rather than generated randomly by
    the caller, so that every index lies in ``[0, values.shape[1])``. Its
    shape is ``[values.shape[0], W]``, with ``W`` drawn via ``self.randInt``
    from ``self.args.tensor_shape_range``.
    """
    num_values = values.shape[1]  # K
    num_indices = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=num_values, size=[values.shape[0], num_indices])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.gatherOp(self.ser, values, indices)

    self.ser.addOperator(op["op"], [values.name, indices.name], [result_tens.name])
    return result_tens
3392
def build_scatter(self, op, values_in, input):
    """Serialize a SCATTER of ``input`` into ``values_in``.

    The indices tensor is created here so that every index lies in
    ``[0, values_in.shape[1])``. Its shape is
    ``[values_in.shape[0], input.shape[1]]``.
    """
    num_values = values_in.shape[1]  # K
    num_indices = input.shape[1]  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=num_values, size=[values_in.shape[0], num_indices])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.scatterOp(self.ser, values_in, indices, input)

    operand_names = [values_in.name, indices.name, input.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name])
    return result_tens
3412
Matthew Haddon848efb42021-09-09 12:30:53 +01003413
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns=None,
    error_name=None,
):
    """Serialize a RESIZE operator on ``input``.

    All resize parameters (mode, integer/float stride and offset, shift and
    output dimensions) are forwarded both to ``OutputShaper.resizeOp`` and
    to the ResizeAttribute. ``validator_fcns`` now defaults to None, in line
    with the other ERROR_IF aware builders. This is backward-compatible.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3482
def build_identityn(self, op, val, val2):
    """Serialize an IDENTITYN operator that passes two tensors through.

    Both outputs are created and wired into the operator, but only the
    first result tensor is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
    # Pass the opcode ``op["op"]`` rather than the whole operator entry,
    # consistent with every other builder.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
3490
def build_const(self, op, val):
    """Register the already-created tensor ``val`` as an output and return it.

    ``op`` is unused; it is accepted only to match the common builder
    signature.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07003494
# Type Conversion
def build_cast(self, op, val, out_dtype):
    """Serialize a CAST of ``val`` to ``out_dtype``; returns the output tensor."""
    out_tensor = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
    self.ser.addOperator(op["op"], [val.name], [out_tensor.name])
    return out_tensor
3500
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns=None,
    error_name=None,
):
    """Serialize a RESCALE of ``val`` to ``out_dtype``.

    Zero points are random for INT8/UINT8 inputs and outputs. For the
    ``InputZeroPointNotZero`` / ``OutputZeroPointNotZero`` ERROR_IF tests
    on other dtypes, a non-zero zero point is forced instead. A random
    scale (one per channel of the last dimension when ``per_channel``) is
    clipped to the range of the selected multiplier width and converted to
    multiplier/shift pairs.

    Args:
        op: Operator entry; ``op["op"]`` is the opcode and
            ``op["operands"]`` the (placeholder, const) operand counts.
        val: Input tensor.
        out_dtype: Output DType.
        scale32: Use 32-bit multipliers when true, 16-bit otherwise.
        double_round: Forwarded to the RESCALE attribute.
        per_channel: One scale per channel (last dim) when true.
        validator_fcns: ERROR_IF validators to run, or None.
        error_name: ERROR_IF test being generated, or None.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width = in_type_width + 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width = in_type_width + 1
    elif error_name == ErrorIf.InputZeroPointNotZero:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            # Force a non-zero zero point so the ERROR_IF condition triggers.
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width = in_type_width + 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width = out_type_width + 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width = out_type_width + 1
    elif error_name == ErrorIf.OutputZeroPointNotZero:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            # Force a non-zero zero point so the ERROR_IF condition triggers.
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width = out_type_width + 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3604
def build_cond_if_const(self, op, then_tens, else_tens, cond):
    """Serialize a COND_IF whose THEN/ELSE blocks each return a constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    branch block outputs an INT32 constant of that shape filled with random
    values in ``[0, 256)``. The condition is the boolean constant ``cond``.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    out_shape = then_tens.shape
    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each branch block is just a single constant marked as its output.
    for block_name, block_data in ((then_block, then_arr), (else_block, else_arr)):
        self.ser.startBasicBlock(block_name)
        block_const = self.ser.addConst(out_shape, DType.INT32, block_data)
        self.ser.addOutputTensor(block_const)

    return result_tens
3640
def build_cond_if_binary(self, op, a, b, cond):
    """Serialize a COND_IF whose branches apply a binary op to ``a`` and ``b``.

    FLOAT/INT32 inputs use ADD in the THEN block and SUB in the ELSE block.
    INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT.
    Any other dtype trips an assertion.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        branch_ops = (Op.ADD, Op.SUB)
    elif a.dtype in (DType.INT8, DType.INT16):
        branch_ops = (Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT)
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Build each branch block; the loop variable is named so it does not
    # shadow the ``op`` parameter.
    for block_name, block_op in zip((then_block, else_block), branch_ops):
        self.ser.startBasicBlock(block_name)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        block_out = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [block_out.name])

    return result_tens
3676
def build_while_loop(self, op, a, iter_val):
    """Serialize a WHILE_LOOP carrying (counter, a, accumulator).

    The counter starts at ``iter_val`` and the accumulator at zero. The
    COND block computes ``counter > 0``. The BODY block computes
    ``acc + a`` and ``counter - 1`` and outputs ``a`` unchanged. The
    operator's accumulator output tensor is returned.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros with a's shape and dtype tag.
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop.
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    # COND block: inputs (counter, a, acc) -> bool (counter > 0).
    self.ser.startBasicBlock(cond_block)
    for tens in (iter_tens, a, acc):
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
    cond_tens = self.ser.addOutput([], DType.BOOL)
    self.ser.addOperator(Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name])

    # BODY block: inputs (counter, a, acc) -> (counter - 1, a, acc + a).
    # Local intermediate tensors must be declared here for the outputs.
    self.ser.startBasicBlock(body_block)
    for tens in (iter_tens, a, acc):
        self.ser.addInputTensor(tens)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
    iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name])
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    return acc_out
3730
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: Operator entry. Reads ``op["rank"]`` as an inclusive
            ``(min, max)`` pair and ``op["types"]`` as the supported dtypes.
            A dtype entry may itself be a list, in which case its first
            element is matched against ``dtypeFilter``.
        shapeFilter: Requested shapes; a falsy value means "no shape filter".
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None.
        testType: ``'positive'`` or ``'negative'``.
        validator: ERROR_IF validator for negative tests. It is called as
            ``validator(check=False, op=op)`` and its ``'param_reqs'`` dict
            supplies ``'rank'``, ``'dtype'`` and ``'shape'`` overrides, where
            None means "no override".

    Returns:
        A dict with ``'shapeFilter'``, ``'rankFilter'`` and ``'dtypeFilter'``
        keys. Returns None for a negative test without a validator or for an
        unrecognised ``testType``.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Bound the default by the default test range or by the operator,
        # whichever is the smaller range of ranks.
        cleanRankFilter = (
            opRankRange
            if len(opRankRange) <= len(default_test_rank_range)
            else default_test_rank_range
        )
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Operator dtypes filtered by the requested dtypes; list entries
        # match on their first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Use the parameters the ERROR_IF requires where it specifies them.
        # ``is not None`` (rather than ``!= None``) stays correct even if a
        # requirement is an array-like object.
        rank_req = error_arguments["rank"]
        dtype_req = error_arguments["dtype"]
        shape_req = error_arguments["shape"]
        return {
            # Otherwise keep only two shapes to keep test numbers small.
            "shapeFilter": shape_req if shape_req is not None else shapeFilter[:2],
            "rankFilter": rank_req if rank_req is not None else cleanRankFilter,
            "dtypeFilter": dtype_req if dtype_req is not None else cleanDtypeFilter,
        }

    return None
3801
3802
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of test descriptions for a single operator.

    Args:
        opName: Key into ``self.TOSA_OP_LIST`` naming the operator to test.
        shapeFilter: Optional list of shapes to restrict generation to. When
            omitted, ``[None]`` is used, i.e. no shape restriction.
        rankFilter: Optional collection of ranks, forwarded to
            ``create_filter_lists``.
        dtypeFilter: Optional collection of dtypes, forwarded to
            ``create_filter_lists``.
        testType: ``'positive'`` for conformant tests, or ``'negative'`` for
            ERROR_IF tests driven by the op's ``error_if_validators``.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples, one per generated test.

    Raises:
        Exception: if ``opName`` is not present in ``self.TOSA_OP_LIST``.
    """
    # Avoid a shared mutable default; ``[None]`` means "any shape".
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # Single pass without a validator.
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # create_filter_lists signals "no tests to generate" with None.
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            if opName.startswith("conv3d"):
                assert r == 5, "conv3d test must have input rank == 5"
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append((opName, testStr, t, error_name, shapeList, args))

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped when ANY validator flags it
        # (previously only the last validator's verdict was honoured).
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
3893
Matthew Haddone86fd342021-09-07 16:12:21 +01003894
3895 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07003896 try:
3897 op = self.TOSA_OP_LIST[opName]
3898 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003899 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003900
3901 # Create a serializer
3902 self.createSerializer(opName, testStr)
3903
Kevin Cheng550ccc52021-03-03 11:21:43 -08003904 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003905 if "error_if_validators" in op:
3906 error_if_validators = op["error_if_validators"]
3907 else:
3908 error_if_validators = None
3909
Kevin Cheng550ccc52021-03-03 11:21:43 -08003910 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003911 num_operands = pCount + cCount
3912
3913 if isinstance(dtype_or_dtypeList, list):
3914 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07003915 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003916 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003917 else:
3918 dtypeList = [dtype_or_dtypeList] * (num_operands)
3919
Kevin Cheng93a16282021-08-31 16:14:03 -07003920 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003921 assert (
3922 len(shapeList) == num_operands
3923 ), "shapeList length {} must match number of operands {}".format(
3924 len(shapeList), num_operands
3925 )
3926 assert (
3927 len(dtypeList) == num_operands
3928 ), "dtypeList length {} must match number of operands {}".format(
3929 len(dtypeList), num_operands
3930 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003931
3932 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003933 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003934 except KeyError:
3935 qgen = None
3936
3937 # Build the random tensor operands and the test
3938 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08003939
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003940 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003941
3942 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003943 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003944 else:
3945 qinfo = None
3946
3947 try:
3948 if error_if_validators is None:
3949 if qinfo is not None:
3950 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
3951 else:
3952 resultName = build_fcn(self, op, *tens, *testArgs)
3953 else:
3954 if qinfo is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003955 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003956 else:
3957 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
3958 except TypeError as e:
3959 print(
3960 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
3961 build_fcn, tens, testArgs
3962 )
3963 )
3964 raise e
3965
3966 if resultName is None:
3967 print("Invalid ERROR_IF tests created")
3968
3969 # Save the serialized test
3970 self.serialize("test")
3971
3972
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
    """Create the input tensors (placeholders, then constants) for one test.

    Most ops get random tensors from ``buildPlaceholderTensors`` /
    ``buildConstTensors``. Several ops constrain their data so a positive
    test avoids wrap-around or invalid inputs:

    * ADD/SUB (INT32 first operand, no ERROR_IF): operand 1 is adjusted so
      the int32 result does not wrap.
    * ARITHMETIC_RIGHT_SHIFT: operand 1 is drawn from [0, 8), [0, 16) or
      [0, 32) for INT8, INT16 or INT32 respectively.
    * SELECT: operand 0 (the condition) is forced to BOOL.
    * INTDIV (no ERROR_IF): regenerates until there are no zero divisors and
      no INT32_MIN / -1 combination.
    * MUL (integer): operands are halved until the rounded, shifted product
      fits in int32; the shift is read from ``testArgs[0]``.
    * CONCAT: ``self.args.num_const_inputs_concat`` decides how many inputs
      become constants; ``testArgs[0]`` is passed to ``tgConcatConstInput``.

    Args:
        op: The TOSA_OP_LIST entry for the operator.
        dtypeList: Per-operand dtypes. Not modified.
        shapeList: Per-operand shapes.
        testArgs: Operator arguments (see MUL / CONCAT above).
        error_name: ERROR_IF condition under test, or ``None`` for a positive
            test.

    Returns:
        List of tensors, placeholders first, as returned by
        ``self.ser.addPlaceholder`` / ``buildPlaceholderTensors`` /
        ``buildConstTensors``.
    """
    pCount, cCount = op["operands"]

    tens = []
    if (
        op["op"] in (Op.ADD, Op.SUB)
        and dtypeList[0] == DType.INT32
        and error_name is None
    ):
        # Make sure the operation does not cause value saturation - where
        # the number wraps due to limited number of bits to store the answer
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
        placeholders = []
        add = op["op"] == Op.ADD
        a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
        b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
        # Compute in int64 so the exact (unwrapped) result is visible
        if add:
            res_arr = np.add(a_arr, b_arr, dtype=np.int64)
        else:
            res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

        # Work out the saturation limits
        max_i32 = (1 << 31) - 1
        min_i32 = -(1 << 31)
        max_arr = np.full(shapeList[1], max_i32)
        min_arr = np.full(shapeList[1], min_i32)

        # Find how much values exceed the maximum/minimums
        sat_max_arr = np.maximum(res_arr - max_arr, 0)
        sat_min_arr = np.minimum(res_arr - min_arr, 0)

        if not add:
            # Swap saturation values and negate values as we need to perform
            # opposite operations
            sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

        # Create new array of unsaturated values by clipping values as needed
        b_unsat_arr = b_arr
        if (sat_max_arr != 0).any():
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor; a
            # broadcast (size-1) axis keeps the smallest value so every
            # broadcast element stays in range
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert (
                        dim == 1
                    ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

        if (sat_min_arr != 0).any():
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor; a
            # broadcast (size-1) axis keeps the largest value
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert (
                        dim == 1
                    ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

        placeholders.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
        )
        placeholders.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
        )

        tens.extend(placeholders)
    elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
        # Force value of operand[1] to be within [0, num_bits) -- the upper
        # bound passed to rng.integers is exclusive
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = self.getRandTensor(shape, dtypeList[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

        tens.extend(placeholders)
    elif op["op"] == Op.SELECT:
        # Set datatype of condition tensor to boolean. Work on a copy so the
        # caller's dtype list is not modified.
        dtypeList = list(dtypeList)
        dtypeList[0] = DType.BOOL
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
    elif op["op"] == Op.INTDIV and error_name is None:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.INTDIV must have 2 placeholders, 0 consts"

        placeholders = []

        # Two invalid cases for Op.INTDIV:
        # 1. divisor == 0
        # 2. dividend == -(1<<31) and divisor == -1
        # Case 2 is checked conservatively: the pair is rejected if both
        # values appear anywhere in the tensors, not only at the same index.
        while True:
            dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

            if (divisor_arr == 0).any():
                continue

            if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                continue

            break

        placeholders.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
        )
        placeholders.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
        )

        tens.extend(placeholders)
    elif op["op"] == Op.MUL:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.MUL must have 2 placeholders, 0 consts"

        if dtypeList[0] == DType.FLOAT:
            tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
        else:
            placeholders = []

            # Make sure multiply result in int32 range
            shift = testArgs[0]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            else:
                raise Exception("OpMul: invalid input dtype")

            low = -(2 ** (num_bits - 1))
            high = (2 ** (num_bits - 1)) - 1

            # NOTE: one operand pair is drawn per entry of shapeList and only
            # the last pair is kept. The extra draws are intentionally kept so
            # the RNG stream, and therefore the generated data, is unchanged.
            for _ in shapeList:
                a_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[1])
                )

            # Halve both operands until the (rounded, shifted) product fits
            while True:
                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2 ** 31)).all() and (
                    result_arr <= ((2 ** 31) - 1)
                ).all():
                    break

                a_arr = a_arr // 2
                b_arr = b_arr // 2

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )

            tens.extend(placeholders)
    elif op["op"] == Op.CONCAT:
        # Number of inputs that become placeholders; the rest are constants
        count = len(shapeList) - self.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if self.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
    else:
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

    return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004173
def createDynamicOpLists(self):
    """Expand the templated convolution entries of ``TOSA_OP_LIST`` in place.

    Each ``<family>_TEMPLATE`` entry is shallow-copied once per kernel size
    into a concrete entry named ``<family>_<k0>x<k1>[x<k2>]``. The copy has
    ``filter`` set to that kernel and ``template`` set to False. Afterwards
    every entry still flagged ``template == True`` is removed.
    """
    op_list = self.TOSA_OP_LIST

    def add_variant(template_name, variant_name, kernel):
        variant = op_list[template_name].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list[variant_name] = variant

    kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    families_2d = ("conv2d", "depthwise_conv2d", "transpose_conv2d")
    for kernel in kernels_2d:
        suffix = "{}x{}".format(kernel[0], kernel[1])
        for family in families_2d:
            add_variant(family + "_TEMPLATE", family + "_" + suffix, kernel)

    kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    for kernel in kernels_3d:
        add_variant(
            "conv3d_TEMPLATE",
            "conv3d_{}x{}x{}".format(kernel[0], kernel[1], kernel[2]),
            kernel,
        )

    # Gather the template names first and delete afterwards, because a dict
    # must not change size while it is being iterated.
    template_names = [
        name
        for name, entry in op_list.items()
        if entry.get("template", False) == True  # noqa: E712 (matches original == True test)
    ]
    for name in template_names:
        del op_list[name]
4220
def initOpListDefaults(self):
    """Lint every ``TOSA_OP_LIST`` entry and fill in optional fields.

    Every entry must provide ``operands`` (unpackable into two values),
    ``build_fcn`` (unpackable into three values), ``types`` and ``op``. The
    first missing or malformed field raises an ``Exception`` naming the op.
    Entries without a ``rank`` receive ``self.DEFAULT_RANK_RANGE``.
    """
    for op_name, op_entry in self.TOSA_OP_LIST.items():
        # Fields that must unpack to a fixed number of elements.
        try:
            _num_placeholders, _num_consts = op_entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op_name)
            )

        try:
            _builder, _tensor_gen, _arg_gen = op_entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        # Fields that only need to be present, checked in this order.
        presence_checks = (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        )
        for field, message in presence_checks:
            if field not in op_entry:
                raise Exception(message.format(op_name))

        # Optional field: default rank range when none is given.
        if "rank" not in op_entry:
            op_entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07004262
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the TOSA opcode (an Op.* value, e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#     ArgGen function or None)
# 'types': list of datatypes to be tested; an entry may itself be a list of
#     per-operand datatypes (see TYPE_CONV)
# 'qgen': optional quantization-info generator (TosaQuantGen.*)
# 'invalid_test_validators': optional; positive tests flagged by these
#     TosaInvalidValidator functions are dropped from the test list
# 'error_if_validators': optional; TosaErrorValidator functions used to
#     generate negative (ERROR_IF) tests
# 'template': optional; True marks an entry that createDynamicOpLists expands
#     into concrete per-kernel entries and then deletes
# 'filter': kernel dimensions; set by createDynamicOpLists on expanded entries
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Convolution type combinations. List entries give one dtype per operand
# (presumably [input, weight, accumulator] -- confirm against build_conv*);
# the bare FLOAT entry is applied to every operand by serializeTest.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07004290
4291 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08004292 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004293 "argmax": {
4294 "op": Op.ARGMAX,
4295 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004296 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004297 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4298 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004299 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
4300 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4301 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004303 "avg_pool2d": {
4304 "op": Op.AVG_POOL2D,
4305 "operands": (1, 0),
4306 "rank": (4, 4),
4307 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4308 "qgen": TosaQuantGen.qgUnary,
4309 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004310 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4311 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4312 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4313 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4314 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004315 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004316 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004317 "conv2d_TEMPLATE": {
4318 "op": Op.CONV2D,
4319 "operands": (1, 2),
4320 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01004321 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004322 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004323 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004324 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004325 "template": True,
4326 },
Kevin Cheng1533b852021-09-01 12:51:58 -07004327 # Templated operator. Filled in by createDynamicOpLists
4328 "conv3d_TEMPLATE": {
4329 "op": Op.CONV3D,
4330 "operands": (1, 2),
4331 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01004332 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07004333 "qgen": TosaQuantGen.qgConv,
4334 "types": TYPE_CONV,
4335 "template": True,
4336 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004337 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004338 "depthwise_conv2d_TEMPLATE": {
4339 "op": Op.DEPTHWISE_CONV2D,
4340 "operands": (1, 2),
4341 "filter": [1, 1],
4342 "rank": (4, 4),
4343 "build_fcn": (
4344 build_depthwise_conv2d,
4345 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01004346 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004347 ),
4348 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004349 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004350 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004351 "template": True,
4352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 "fully_connected": {
4354 "op": Op.FULLY_CONNECTED,
4355 "operands": (1, 2),
4356 "rank": (2, 2),
4357 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
4358 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004359 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004360 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
4361 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004362 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004363 "matmul": {
4364 "op": Op.MATMUL,
4365 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07004366 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
4368 "qgen": TosaQuantGen.qgMatmul,
4369 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004370 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4371 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004372 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004373 "max_pool2d": {
4374 "op": Op.MAX_POOL2D,
4375 "operands": (1, 0),
4376 "rank": (4, 4),
4377 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4378 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004379 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4380 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4381 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4382 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004383 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004384 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004385 "transpose_conv2d_TEMPLATE": {
4386 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07004387 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004388 "rank": (4, 4),
4389 "build_fcn": (
4390 build_transpose_conv2d,
4391 TosaTensorGen.tgTransposeConv2D,
4392 TosaArgGen.agTransposeConv2D,
4393 ),
4394 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004395 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004396 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004397 "template": True,
4398 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004399 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08004400 "clamp": {
4401 "op": Op.CLAMP,
4402 "operands": (1, 0),
4403 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
4404 "types": TYPE_NARROW_INT_FP,
4405 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004406 "sigmoid": {
4407 "op": Op.SIGMOID,
4408 "operands": (1, 0),
4409 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
4410 "types": TYPE_FP,
4411 },
4412 "tanh": {
4413 "op": Op.TANH,
4414 "operands": (1, 0),
4415 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
4416 "types": TYPE_FP,
4417 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004418 # Elementwise Binary Operators
4419 "add": {
4420 "op": Op.ADD,
4421 "operands": (2, 0),
4422 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4423 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004424 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4425 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004426 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004427 "arithmetic_right_shift": {
4428 "op": Op.ARITHMETIC_RIGHT_SHIFT,
4429 "operands": (2, 0),
4430 "build_fcn": (
4431 build_arithmetic_right_shift,
4432 TosaTensorGen.tgBroadcastFuzz,
4433 TosaArgGen.agArithmeticRightShift,
4434 ),
4435 "types": TYPE_INT,
4436 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004437 "bitwise_and": {
4438 "op": Op.BITWISE_AND,
4439 "operands": (2, 0),
4440 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4441 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004442 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4443 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004444 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004445 "bitwise_or": {
4446 "op": Op.BITWISE_OR,
4447 "operands": (2, 0),
4448 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4449 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004450 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4451 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004452 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004453 "bitwise_xor": {
4454 "op": Op.BITWISE_XOR,
4455 "operands": (2, 0),
4456 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4457 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004458 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4459 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004460 },
Matthew Haddon459443c2021-08-23 16:43:13 +01004461 "intdiv": {
4462 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004463 "operands": (2, 0),
4464 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4465 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004466 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4467 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004468 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004469 "logical_and": {
4470 "op": Op.LOGICAL_AND,
4471 "operands": (2, 0),
4472 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4473 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004474 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4475 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004477 "logical_left_shift": {
4478 "op": Op.LOGICAL_LEFT_SHIFT,
4479 "operands": (2, 0),
4480 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4481 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004482 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4483 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004484 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004485 "logical_right_shift": {
4486 "op": Op.LOGICAL_RIGHT_SHIFT,
4487 "operands": (2, 0),
4488 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4489 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004490 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4491 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004492 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004493 "logical_or": {
4494 "op": Op.LOGICAL_OR,
4495 "operands": (2, 0),
4496 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4497 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004498 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4499 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004500 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004501 "logical_xor": {
4502 "op": Op.LOGICAL_XOR,
4503 "operands": (2, 0),
4504 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4505 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004506 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4507 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004508 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004509 "maximum": {
4510 "op": Op.MAXIMUM,
4511 "operands": (2, 0),
4512 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4513 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004514 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4515 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004516 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004517 "minimum": {
4518 "op": Op.MINIMUM,
4519 "operands": (2, 0),
4520 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4521 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004522 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4523 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004524 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004525 "mul": {
4526 "op": Op.MUL,
4527 "operands": (2, 0),
4528 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
4529 "types": TYPE_INT_FP,
4530 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004531 "pow": {
4532 "op": Op.POW,
4533 "operands": (2, 0),
4534 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
4535 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004536 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4537 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004538 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004539 "sub": {
4540 "op": Op.SUB,
4541 "operands": (2, 0),
4542 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4543 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004544 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4545 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004546 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004547 "table": {
4548 "op": Op.TABLE,
4549 # Use the automatic generation functions to create the input array
4550 # but create the table tensor in the build function, as it may be
4551 # a different type from the input
4552 "operands": (1, 0),
4553 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004554 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08004555 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004556 # Elementwise Unary operators
4557 "abs": {
4558 "op": Op.ABS,
4559 "operands": (1, 0),
4560 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4561 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004562 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4563 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004564 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004565 "bitwise_not": {
4566 "op": Op.BITWISE_NOT,
4567 "operands": (1, 0),
4568 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4569 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004570 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4571 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004572 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004573 "ceil": {
4574 "op": Op.CEIL,
4575 "operands": (1, 0),
4576 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4577 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004578 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4579 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004581 "clz": {
4582 "op": Op.CLZ,
4583 "operands": (1, 0),
4584 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4585 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004586 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4587 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004588 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004589 "exp": {
4590 "op": Op.EXP,
4591 "operands": (1, 0),
4592 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4593 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004594 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4595 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004596 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004597 "floor": {
4598 "op": Op.FLOOR,
4599 "operands": (1, 0),
4600 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4601 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004602 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4603 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004605 "log": {
4606 "op": Op.LOG,
4607 "operands": (1, 0),
4608 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4609 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004610 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4611 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004612 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004613 "logical_not": {
4614 "op": Op.LOGICAL_NOT,
4615 "operands": (1, 0),
4616 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4617 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004618 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4619 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004620 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004621 "negate": {
4622 "op": Op.NEGATE,
4623 "operands": (1, 0),
4624 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4625 "qgen": TosaQuantGen.qgUnary,
4626 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004627 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4628 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4629 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004630 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004631 "reciprocal": {
4632 "op": Op.RECIPROCAL,
4633 "operands": (1, 0),
4634 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4635 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004636 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4637 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004639 "rsqrt": {
4640 "op": Op.RSQRT,
4641 "operands": (1, 0),
4642 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4643 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004644 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4645 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004646 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004647 # Elementwise Ternary operators
4648 "select": {
4649 "op": Op.SELECT,
4650 "operands": (3, 0),
4651 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
4652 "types": TYPE_FIB,
4653 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004654 # Comparison operators
4655 "equal": {
4656 "op": Op.EQUAL,
4657 "operands": (2, 0),
4658 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4659 "types": TYPE_FI32,
4660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004661 "greater_equal": {
4662 "op": Op.GREATER_EQUAL,
4663 "operands": (2, 0),
4664 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4665 "types": TYPE_FI32,
4666 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004667 "greater": {
4668 "op": Op.GREATER,
4669 "operands": (2, 0),
4670 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4671 "types": TYPE_FI32,
4672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004673 # Reduction operators
4674 "reduce_all": {
4675 "op": Op.REDUCE_ALL,
4676 "operands": (1, 0),
4677 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4678 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004679 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4680 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4681 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004682 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004683 "reduce_any": {
4684 "op": Op.REDUCE_ANY,
4685 "operands": (1, 0),
4686 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4687 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004688 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4689 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4690 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004691 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004692 "reduce_max": {
4693 "op": Op.REDUCE_MAX,
4694 "operands": (1, 0),
4695 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4696 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004697 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4698 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4699 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004701 "reduce_min": {
4702 "op": Op.REDUCE_MAX,
4703 "operands": (1, 0),
4704 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4705 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004706 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4707 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4708 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004709 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004710 "reduce_product": {
4711 "op": Op.REDUCE_PRODUCT,
4712 "operands": (1, 0),
4713 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4714 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004715 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4716 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4717 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004718 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004719 "reduce_sum": {
4720 "op": Op.REDUCE_SUM,
4721 "operands": (1, 0),
4722 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4723 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004724 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4725 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4726 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004727 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004728 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004729 "concat": {
4730 "op": Op.CONCAT,
4731 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01004732 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004733 "types": TYPE_FIB,
4734 },
4735 "pad": {
4736 "op": Op.PAD,
4737 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004738 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004739 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4740 "qgen": TosaQuantGen.qgPad,
4741 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004742 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
4743 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004744 },
4745 "reshape": {
4746 "op": Op.RESHAPE,
4747 "operands": (1, 0),
4748 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4749 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004750 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4751 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004752 },
4753 "reverse": {
4754 "op": Op.REVERSE,
4755 "operands": (1, 0),
4756 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4757 "types": TYPE_FIB,
4758 },
4759 "slice": {
4760 "op": Op.SLICE,
4761 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004762 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004763 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4764 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004765 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
4766 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
4767 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004768 },
4769 "tile": {
4770 "op": Op.TILE,
4771 "operands": (1, 0),
4772 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4773 "types": TYPE_FIB,
4774 },
4775 "transpose": {
4776 "op": Op.TRANSPOSE,
4777 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004778 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004779 "build_fcn": (
4780 build_transpose,
4781 TosaTensorGen.tgBasic,
4782 TosaArgGen.agTranspose,
4783 ),
4784 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004785 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
4786 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004787 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004788 # Data nodes
4789 "const": {
4790 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004791 "operands": (0, 1),
4792 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004793 "types": TYPE_FIB,
4794 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004795 "identity": {
4796 "op": Op.IDENTITY,
4797 "operands": (1, 0),
4798 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4799 "types": TYPE_FIB,
4800 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004801 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004802 "gather": {
4803 "op": Op.GATHER,
4804 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4805 "operands": (1, 0),
4806 "rank": (3, 3),
4807 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4808 "types": TYPE_INT_FP,
4809 },
4810 "scatter": {
4811 "op": Op.SCATTER,
4812 # Only specify 'values_in' tensor here.
4813 #'indices' and 'input' are generated in op building stage
4814 "operands": (2, 0),
4815 "rank": (3, 3),
4816 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4817 "types": TYPE_INT_FP,
4818 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004819 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004820 "resize": {
4821 "op": Op.RESIZE,
4822 "operands": (1, 0),
4823 "rank": (4, 4),
4824 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4825 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004826 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4827 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4828 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004829 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004830 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4831 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004832 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004833 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004834 "cast": {
4835 "op": Op.CAST,
4836 "operands": (1, 0),
4837 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4838 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4839 },
4840 "rescale": {
4841 "op": Op.RESCALE,
4842 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01004843 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004844 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004845 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01004846 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
4847 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4848 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004849 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004850 # Custom
4851 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004852 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004853 # Two varients of cond_if, one that generates one of two constant tensors (no
4854 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4855 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004856 "cond_if_const": {
4857 "op": Op.COND_IF,
4858 "operands": (0, 2),
4859 "build_fcn": (
4860 build_cond_if_const,
4861 TosaTensorGen.tgBasic,
4862 TosaArgGen.agCondIf,
4863 ),
4864 "types": [DType.BOOL],
4865 },
4866 "cond_if_binary": {
4867 "op": Op.COND_IF,
4868 "operands": (2, 0),
4869 "build_fcn": (
4870 build_cond_if_binary,
4871 TosaTensorGen.tgBasic,
4872 TosaArgGen.agCondIf,
4873 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004874 "types": TYPE_INT_FP,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004875 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004876 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004877 "while_loop": {
4878 "op": Op.WHILE_LOOP,
4879 "operands": (0, 1),
4880 "build_fcn": (
4881 build_while_loop,
4882 TosaTensorGen.tgBasic,
4883 TosaArgGen.agWhileLoop,
4884 ),
4885 "types": [DType.INT32],
4886 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004887 }
4888
Kevin Cheng550ccc52021-03-03 11:21:43 -08004889
Eric Kunzee5e26762020-10-13 16:11:07 -07004890class OutputShaper:
# OutputShaper groups helpers that work out the expected output shape and
# datatype for common classes of operations.
def __init__(self):
    """Construct an OutputShaper; the constructor stores no state."""
4895
# These methods work out the expected output shape and dtype, create the
# output tensor by calling ser.addOutput(shape, dtype), and return its result.
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an elementwise binary op with broadcasting.

    For a valid test (error_name is None) each output dimension takes b's
    size where a's size is 1, and a's size otherwise. When an ERROR_IF case
    is being generated the output simply keeps a's shape.

    Args:
        ser: object whose addOutput(shape, dtype) result is returned.
        rng: random generator; only used (via choice) to pick a wrong dtype
            for the WrongOutputType case.
        a, b: input tensors exposing .shape and .dtype.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        The result of ser.addOutput(shape, dtype).
    """
    # A rank mismatch is itself the error under test, so rank agreement is
    # only required for the other cases.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, size_a in enumerate(a.shape):
        # Broadcast only for valid tests; for error cases b.shape is never
        # indexed, so differing ranks cannot cause an IndexError here.
        if size_a == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(size_a)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        # Any dtype other than the input's is deliberately wrong. The order of
        # this list feeds rng.choice, so it is built exactly as before to keep
        # seeded test generation unchanged.
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004919
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose inputs must match exactly.

    Both inputs must agree in rank, dtype and every dimension size (checked
    with assert). The output uses a copy of a's shape and a's dtype.

    Returns:
        The result of ser.addOutput(shape, a.dtype).
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for size_a, size_b in zip(a.shape, b.shape):
        assert size_a == size_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004931
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an elementwise unary op.

    The output always has a's shape. Its dtype is a's dtype, except for the
    ErrorIf.WrongOutputType case, where rng.choice picks a different dtype
    from INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        The result of ser.addOutput(a.shape, dtype).
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        # Exclude the input dtype so the chosen one is guaranteed wrong.
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004942
@staticmethod
def selectOp(ser, cond, a, b):
    """Create the output tensor for SELECT(cond, a, b).

    All three inputs must share a rank, and a/b must share a dtype (checked
    with assert). Each output dimension is the largest of the three input
    sizes at that position; the dtype is a's.

    Returns:
        The result of ser.addOutput(shape, a.dtype).
    """
    assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [max(sizes) for sizes in zip(cond.shape, a.shape, b.shape)]
    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004953
@staticmethod
def binaryComparisonOp(ser, a, b):
    """Create the output tensor for a broadcasting comparison op.

    Inputs must share rank and dtype (checked with assert). Each output
    dimension takes b's size where a's size is 1, and a's size otherwise.
    The output dtype is always DType.BOOL.

    Returns:
        The result of ser.addOutput(shape, DType.BOOL).
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast: a size-1 dimension of a takes b's size.
    out_shape = [
        size_b if size_a == 1 else size_a
        for size_a, size_b in zip(a.shape, b.shape)
    ]

    # Comparison results are booleans regardless of the input dtype.
    return ser.addOutput(out_shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004969
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a REDUCE_* op along ``axis``.

    For a valid test the reduced axis becomes size 1 and the dtype is a's.
    ERROR_IF cases change this:
      * AxisSmallerZero / AxisLargerRank: the input shape is kept as is.
      * ShapeOfAxisNotOne: the axis keeps its input size, and is replaced by
        rng.integers(2, 10) if that size happens to be 1.
      * WrongOutputType: rng.choice picks a dtype different from a's.

    Returns:
        The result of ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # The error under test needs the reduced axis to be something other than 1.
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004986
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for ARGMAX over ``axis``.

    For a valid test ``axis`` is removed from the input shape and the dtype
    is DType.INT32. ERROR_IF cases change this:
      * AxisSmallerZero / AxisLargerRank: the input shape is kept as is.
      * ArgmaxOutputRankMismatch: rng.choice decides whether to delete the
        leading dimension (only while more than one remains); otherwise a
        size-1 dimension is appended.
      * ArgmaxOutputShapeMismatch: each remaining dimension is increased by
        rng.integers(1, 10).
      * WrongOutputType: rng.choice picks a dtype other than INT32.

    Returns:
        The result of ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, size in enumerate(out_shape):
            out_shape[idx] = size + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005012
@staticmethod
def conv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Create the output tensor for a 2D convolution.

    Layouts: ifm NHWC, filter OHWI, output NHWC.

    Args:
        ser: object whose addOutput(shape, dtype) result is returned.
        ifm: input tensor exposing .shape and .dtype.
        filter: weight tensor exposing .shape.
        strides: [stride_h, stride_w].
        padding: [top, bottom, left, right], or a 2-element [pad_h, pad_w]
            (the transpose_conv2d case) that is expanded to that form.
        dilations: [dilation_h, dilation_w].

    Raises:
        Exception: if ifm.dtype is not INT8, INT16 or FLOAT.
    """
    if len(padding) == 2:
        # transpose_conv2d supplies H,W padding; expand it to T,B,L,R.
        pad_h, pad_w = padding[0], padding[1]
        padding = [pad_h, pad_h, pad_w, pad_w]

    def out_extent(in_size, k_size, dilation, pad_before, pad_after, stride):
        # Dilation widens the kernel by (k_size - 1) * (dilation - 1).
        return (
            in_size
            - k_size
            - (k_size - 1) * (dilation - 1)
            + pad_before
            + pad_after
        ) // stride + 1

    height = out_extent(
        ifm.shape[1], filter.shape[1], dilations[0], padding[0], padding[1], strides[0]
    )
    width = out_extent(
        ifm.shape[2], filter.shape[2], dilations[1], padding[2], padding[3], strides[1]
    )
    ofm_shape = [ifm.shape[0], height, width, filter.shape[0]]

    # Output dtype for each supported input dtype.
    out_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtypes:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtypes[ifm.dtype])
Eric Kunzee5e26762020-10-13 16:11:07 -07005053
@staticmethod
def conv3dOp(ser, ifm, filter, strides, padding, dilations):
    """Create the output tensor for a 3D convolution.

    Layouts: ifm NDHWC, filter ODHWI, output NDHWC. ``strides`` and
    ``dilations`` are indexed [D, H, W]; ``padding`` holds (before, after)
    pairs for D, H and W in that order.

    Raises:
        Exception: if ifm.dtype is not INT8, INT16 or FLOAT.
    """
    # Spatial output extents, computed in D, H, W order.
    extents = []
    for dim in range(3):
        in_size = ifm.shape[dim + 1]
        k_size = filter.shape[dim + 1]
        extents.append(
            (
                in_size
                - k_size
                - (k_size - 1) * (dilations[dim] - 1)
                + padding[2 * dim]
                + padding[2 * dim + 1]
            )
            // strides[dim]
            + 1
        )
    depth, height, width = extents

    ofm_shape = [ifm.shape[0], depth, height, width, filter.shape[0]]

    # Output dtype for each supported input dtype.
    out_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtypes:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtypes[ifm.dtype])
5097
@staticmethod
def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Create the output tensor for a depthwise 2D convolution.

    Layouts: ifm NHWC, filter HWCM, output NHW(C*M). ``padding`` is
    [top, bottom, left, right]; ``strides`` and ``dilations`` are [H, W].

    Raises:
        Exception: if ifm.dtype is not INT8, INT16 or FLOAT.
    """
    # HWCM filter: spatial sizes sit at indices 0 and 1.
    kernel_h = filter.shape[0]
    kernel_w = filter.shape[1]

    height = (
        ifm.shape[1]
        - kernel_h
        - (kernel_h - 1) * (dilations[0] - 1)
        + padding[0]
        + padding[1]
    ) // strides[0] + 1
    width = (
        ifm.shape[2]
        - kernel_w
        - (kernel_w - 1) * (dilations[1] - 1)
        + padding[2]
        + padding[3]
    ) // strides[1] + 1

    # Output channel count is C * M.
    ofm_shape = [ifm.shape[0], height, width, filter.shape[2] * filter.shape[3]]

    out_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(ifm.dtype)
    if out_dtype is None:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005131
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Compute a 2D pooling output shape/dtype and register it with ``ser``.

    ``ifm`` is NHWC. For ``kernel`` and ``stride``, index 0 applies to H and
    index 1 to W; ``pad[0:2]`` is added to H and ``pad[2:4]`` to W.

    ``error_name`` selects an ERROR_IF variant:
      * ``ErrorIf.PoolingOutputShapeMismatch`` grows H and W by a value
        drawn from [1, 2, 3, 4, 5] (H drawn first, then W).
      * ``ErrorIf.WrongOutputType`` draws an output dtype other than
        ``ifm.dtype``.
    """
    if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
        # An invalid stride/pad gives no meaningful size; the test is
        # invalid anyway, so fall back to 1x1.
        out_h, out_w = 1, 1
    else:
        out_h, out_w = (
            (ifm.shape[i + 1] + pad[2 * i] + pad[2 * i + 1] + stride[i] - kernel[i])
            // stride[i]
            for i in (0, 1)
        )

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        extra = [1, 2, 3, 4, 5]
        out_h += rng.choice(extra)
        out_w += rng.choice(extra)

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005158
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Compute the FULLY_CONNECTED output shape/dtype and register it.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).

    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input. ERROR_IF variants:
      * ``ErrorIf.WrongOutputType`` draws a dtype other than that one from
        (INT4, INT8, INT16, INT32, INT48, FLOAT).
      * ``ErrorIf.WrongInputType`` uses INT32 when the input dtype is not
        one of the supported ones.

    Returns whatever ``ser.addOutput`` returns.

    Raises:
        Exception: for an unsupported input dtype, unless ``error_name`` is
            ``ErrorIf.WrongInputType``.
    """
    output_shape = [input.shape[0], filter.shape[0]]

    expected_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(input.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if expected_dtype is None:
            # There is no "correct" dtype to deviate from; fail with a clear
            # message rather than an unbound local variable.
            raise Exception("Unsupported input dtype: {}".format(input.dtype))
        # Fixed order so a seeded rng yields reproducible choices.
        candidates = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in candidates if t != expected_dtype)
        out_dtype = rng.choice(a=incorrect_types)
    elif expected_dtype is not None:
        out_dtype = expected_dtype
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005188
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Compute the MATMUL output shape/dtype and register it with ``ser``.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    The output dtype is INT32 for INT8 ``a``, INT48 for INT16 and FLOAT for
    FLOAT. ERROR_IF variants:
      * ``ErrorIf.WrongOutputType`` draws a dtype other than that one from
        (INT4, INT8, INT16, INT32, INT48, FLOAT).
      * ``ErrorIf.WrongInputType`` uses INT32 when ``a.dtype`` is not one
        of the supported ones.

    Returns whatever ``ser.addOutput`` returns.

    Raises:
        Exception: for an unsupported ``a.dtype``, unless ``error_name`` is
            ``ErrorIf.WrongInputType``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    expected_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(a.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if expected_dtype is None:
            # There is no "correct" dtype to deviate from; fail with a clear
            # message rather than an unbound local variable.
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
        # Fixed order so a seeded rng yields reproducible choices.
        candidates = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in candidates if t != expected_dtype)
        out_dtype = rng.choice(a=incorrect_types)
    elif expected_dtype is not None:
        out_dtype = expected_dtype
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005218
@staticmethod
def concatOp(ser, axis, *a):
    """Compute the CONCAT output shape and register it with ``ser``.

    The output starts as a copy of the first input's shape and uses its
    dtype. Its extent along ``axis`` is the sum of every input's extent
    along that axis.
    """
    first = a[0]
    output_shape = first.shape.copy()
    output_shape[axis] = sum((t.shape[axis] for t in a[1:]), first.shape[axis])
    return ser.addOutput(output_shape, first.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005232
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Compute the PAD output shape/dtype and register it with ``ser``.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ``error_name`` selects an ERROR_IF variant:
      * ``ErrorIf.PadSmallerZero`` clamps any dimension below 1 up to 1.
      * ``ErrorIf.WrongOutputType`` draws a dtype other than ``a.dtype``.
    """
    output_shape = a.shape.copy()
    for axis, dim in enumerate(output_shape):
        output_shape[axis] = padding[axis][0] + padding[axis][1] + dim

    if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
        # The ERROR_IF padding can drive a dimension below 1; clamp it so the
        # output shape stays usable.
        output_shape = [dim if dim >= 1 else 1 for dim in output_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005253
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Compute the RESHAPE output shape/dtype and register it with ``ser``.

    The output starts as a copy of ``shape``. Any ``-1`` entry is replaced
    by the input element count divided by the product of the other entries.

    ``error_name`` selects an ERROR_IF variant:
      * ``ErrorIf.TensorSizeInputOutputMismatch`` adds
        ``rng.integers(1, 10)`` to every output dimension, in order.
      * ``ErrorIf.WrongOutputType`` draws a dtype other than ``a.dtype``.
    """
    output_shape = shape.copy()

    input_count = 1
    for dim in a.shape:
        input_count *= dim

    known_count = 1
    for dim in output_shape:
        if dim == -1:
            continue
        known_count *= dim

    for axis, dim in enumerate(output_shape):
        if dim == -1:
            output_shape[axis] = input_count // known_count

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for axis, dim in enumerate(output_shape):
            output_shape[axis] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005285
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Compute the SLICE output shape/dtype and register it with ``ser``.

    The output shape is a copy of ``size``; ``start`` does not affect it.

    ``error_name`` selects an ERROR_IF variant:
      * ``ErrorIf.WrongOutputType`` draws a dtype other than ``a.dtype``.
      * ``ErrorIf.SizeOutputShapeMismatch`` perturbs every dimension by +1
        or +2 (dims <= 2) or by -2, -1, +1 or +2 (larger dims).
    The dtype is drawn before any shape perturbation.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for axis, dim in enumerate(output_shape):
            # Small dims only grow so they never reach zero or below.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[axis] = dim + rng.choice(deltas)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005307
@staticmethod
def tileOp(ser, a, multiples):
    """Compute the TILE output shape and register it with ``ser``.

    Each output dimension is the input dimension times the matching entry
    of ``multiples``; the dtype is ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for axis, dim in enumerate(a.shape):
        output_shape[axis] = dim * multiples[axis]

    return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005318
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Compute the TRANSPOSE output shape/dtype and register it with ``ser``.

    Output dimension ``i`` is ``a.shape[perms[i]]``.

    ``error_name`` selects an ERROR_IF variant:
      * ``ErrorIf.IndexOutsideBounds`` ignores ``perms`` and uses
        ``a.shape[0]`` for every dimension.
      * ``ErrorIf.WrongOutputType`` draws a dtype other than ``a.dtype``.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    if error_name == ErrorIf.IndexOutsideBounds:
        # perms is not indexed into a.shape in this case.
        source_axes = [0] * len(output_shape)
    else:
        source_axes = perms
    for axis in range(len(output_shape)):
        output_shape[axis] = a.shape[source_axes[axis]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005340
@staticmethod
def gatherOp(ser, values, indices):
    """Compute the GATHER output shape and register it with ``ser``.

    ``values`` must be rank 3 and ``indices`` rank 2, with matching dim 0.
    The output is ``[values.shape[0], indices.shape[1], values.shape[2]]``
    with ``values.dtype``.
    """
    assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]

    return ser.addOutput([batch, width, channels], values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005350
@staticmethod
def scatterOp(ser, values_in, indices, input):
    """Compute the SCATTER output shape and register it with ``ser``.

    ``values_in`` and ``input`` must be rank 3 and ``indices`` rank 2, with
    matching N (values_in/indices dim 0), W (input/indices dim 1) and C
    (values_in/input dim 2). The output uses ``values_in``'s shape object
    and dtype.
    """
    vin_shape = values_in.shape
    idx_shape = indices.shape
    in_shape = input.shape

    assert len(vin_shape) == 3
    assert len(idx_shape) == 2
    assert len(in_shape) == 3
    assert vin_shape[0] == idx_shape[0]  # N
    assert in_shape[1] == idx_shape[1]  # W
    assert vin_shape[2] == in_shape[2]  # C

    # NOTE(review): the shape list is handed over as-is, not copied.
    return ser.addOutput(vin_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005363
@staticmethod
def tableOp(ser, input, table_dtype):
    """Register the TABLE output: same shape as ``input``, dtype from table.

    An INT16 table gives INT32 output; an INT8 table gives INT8 output.
    """
    assert table_dtype in (DType.INT16, DType.INT8)
    if table_dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8
    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005370
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Compute the RESIZE output shape and register it with ``serializer``.

    The shape is ``[N, output_dims[0], output_dims[1], C]``, with N and C
    taken from ``input.shape[0]`` and ``input.shape[3]``, and the dtype is
    ``output_dtype``. ``mode``, ``stride``, ``offset``, ``shift``,
    ``stride_fp``, ``offset_fp`` and ``input_dtype`` are not read here.

    ``error_name`` selects an ERROR_IF variant:
      * ``ErrorIf.WrongRank`` builds
        ``[input.shape[0], output_dims[0], output_dims[0], input.shape[0]]``.
      * ``ErrorIf.BatchMismatch`` adds ``rng.integers(1, 10)`` to N.
      * ``ErrorIf.ChannelMismatch`` adds ``rng.integers(1, 10)`` to C.
    """
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): only input.shape[0] is read, presumably because the
        # input may not be rank 4 here, and output_dims[0] fills both middle
        # positions - confirm that is intended.
        out_shape = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
    else:
        batch = input.shape[0]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        height, width = output_dims[0], output_dims[1]
        channels = input.shape[3]
        if error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)
        out_shape = [batch, height, width, channels]

    return serializer.addOutput(out_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005398
@staticmethod
def typeConversionOp(ser, val, out_dtype):
    """Register an output with ``val``'s shape and the requested dtype."""
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005402
@staticmethod
def transposeConv2DOp(ser, ifm, output_shape):
    """Register the TRANSPOSE_CONV2D output with a caller-supplied shape.

    Only the dtype is derived here: INT8 -> INT32, INT16 -> INT48 and
    FLOAT -> FLOAT.

    Raises:
        Exception: if ``ifm.dtype`` is none of the above.
    """
    out_dtype_for = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtype_for:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(output_shape, out_dtype_for[ifm.dtype])