blob: 105f016ef29e48ce5953dbb35313837dc19e5dcd [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
# The serialization library's Python bindings live in
# ../thirdparty/serialization_lib/python relative to this script's real
# location; append that directory to the module search path so the
# tosa_serializer / tosa imports below can be resolved.
parent_dir = os.path.split(os.path.realpath(__file__))[0]
sys.path.append(
    os.path.join(parent_dir, os.pardir, "thirdparty", "serialization_lib", "python")
)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# The flatc-generated bindings model these enumerations as plain classes
# rather than Python enums.  One instance of each is created here and used
# as a namespace of values throughout the file (e.g. DType.INT8).
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    An operator selects one of the qg* helpers with the 'qgen' key in its
    operator definition.  Each helper builds a ts.TosaSerializerQuantInfo whose
    zero points come from getQinfo(); when error_name names a zero-point
    ERROR_IF case, only the matching zero point is generated in error form.
    """

    def __init__(self):
        """No state is kept; every helper is a static method."""

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point for ``dtype``.

        INT8 draws testGen.randInt(-128, 128) and UINT8 draws
        testGen.randInt(0, 256).  Every other dtype gets 0, unless
        ``error_name`` is one of the *ZeroPointNotZero ERROR_IF cases, in
        which case a non-zero value is drawn (a drawn 0 is replaced by 1).
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        if error_name not in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            return 0
        # ERROR_IF test: this dtype requires a zero point of 0, so force non-zero.
        zero_point = testGen.randInt(-128, 128)
        return zero_point if zero_point != 0 else 1

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo(input_zp, output_zp) for ``dtype``."""
        input_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_error = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        qinfo = ts.TosaSerializerQuantInfo()
        # The input zero point is drawn before the output one (RNG order matters).
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype, input_error),
            TosaQuantGen.getQinfo(testGen, dtype, output_error),
        )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo(input_zp, weights_zp).

        ``dtype_or_dtypeList`` is either a list of [input, weights, accumulator]
        dtypes or a single dtype used for all three.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        dtypeList = (
            dtype_or_dtypeList
            if isinstance(dtype_or_dtypeList, list)
            else [dtype_or_dtypeList] * 3
        )
        input_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weights_error = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], input_error)
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], weights_error)
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo(a_zp, b_zp); the input ERROR_IF case affects both."""
        qinfo = ts.TosaSerializerQuantInfo()
        shared_error = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype, shared_error),
            TosaQuantGen.getQinfo(testGen, dtype, shared_error),
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo(input_zp) for ``dtype``."""
        qinfo = ts.TosaSerializerQuantInfo()
        input_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, input_error))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  The multiplier uses
        31 bits when ``scale32`` is true and 15 bits otherwise.  The shift is
        adjusted so that it always ends up in [2, 62].
        """
        scaleBits = 31 if scale32 else 15
        full_scale = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * full_scale)
        assert multiplier <= full_scale
        if multiplier == full_scale:
            # Rounding reached 1.0: halve the multiplier and bump the exponent.
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Pull the shift back into the permitted range [2, 62].
        if shift == 0:
            multiplier, shift = multiplier // 4, 2
        elif shift == 1:
            multiplier, shift = multiplier // 2, 2
        elif shift == 63:
            multiplier, shift = multiplier * 2, 62

        assert multiplier <= full_scale
        assert 2 <= shift <= 62
        return multiplier, shift
175
176
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately
    for each test.

    Most generators take the test generator object (``testGen``), the operator
    definition (subscripted for ``"operands"`` = (placeholder count, const count)
    and, for convolutions, ``"filter"``), the requested rank and an optional
    ERROR_IF test name, and return one shape per operand.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one shape of ``rank`` per operand, all identical by default.

        For ErrorIf.RankMismatch the shape is regenerated with a different rank
        after every input except index 1 is appended, so input 1 never shares
        the rank of input 0.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical NHWC shapes for every operand.

        Rank must be 4 except for WrongRank tests.  When max_batch_size is set,
        the batch dimension is wrapped into [1, max_batch_size].
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for a rank-3 scatter.

        input_shape shares dims 0 and 2 with values_in_shape but has a random
        middle dimension W.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            # NOTE(review): randInt(0, 16) may be able to return 0, giving an
            # empty input tensor -- confirm that is intended.
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return per-operand shapes where one random input has a dim set to 1.

        For ErrorIf.RankMismatch, broadcasting is disabled and every input
        other than index 1 is regenerated with a different rank.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm NHWC, filter OHWI, bias OC] shapes for conv2d."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm NDHWC, filter ODHWI, bias OC] shapes for conv3d."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm NHWC, filter OHWI, bias OC] shapes for transpose_conv2d."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm NHWC, filter HWCM, bias M*C] shapes for depthwise conv2d.

        The channel multiplier M is kept to at most a quarter of the configured
        maximum dimension (and at least 1), because the output depth is M * C.
        """
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C.  Clamp the modulus to >= 1 so a small
        # tensor_shape_range maximum (< 4) cannot cause a division by zero.
        m_limit = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeShape(1)[0] % m_limit) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter (OC, IC), bias (OC)] shapes for fully_connected."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # Use the returned shape rather than relying on in-place mutation.
        input_shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] with b = [a[0], a[2], b_oc] for matmul."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # Use the returned shape rather than relying on in-place mutation.
        a_shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return identical shapes for the operator's operands plus 0-3 extra inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat (num_tensors is drawn by testGen.randInt(0, 4)
        # and added on top of the operator's own operand count).
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the concat shape along ``axis`` across the const inputs.

        The first entry keeps the original shape.  The remaining
        len(shapeList) - 1 entries repeatedly halve the axis length, and the
        last two share the remainder, so their axis lengths sum to the
        original.  The input is returned unchanged when there are only two
        shapes or the axis is shorter than the number of shapes.
        """
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = shape[axis] // 2
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
490
491
Eric Kunzee5e26762020-10-13 16:11:07 -0700492class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800493 """Argument generators create exhaustive or random lists of attributes for operators that take
494 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
495 tuples where the descriptive_name is appended to the test name and the arglist is expanded
496 as arguments to the operator build function."""
497
Eric Kunzee5e26762020-10-13 16:11:07 -0700498 def __init__(self):
499 pass
500
501 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100502 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800503 """A trivial argument generator for operators that don't take any
504 non-tensor arguments"""
505 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
507 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100508 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800509 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700511 shape = shapeList[0]
512
Matthew Haddond6ce7252021-09-29 15:35:44 +0100513 if error_name == ErrorIf.AxisSmallerZero:
514 small_axis = testGen.rng.integers(-5, 0)
515 axes.append(("axis{}".format(small_axis), [small_axis]))
516 elif error_name == ErrorIf.AxisLargerRank:
517 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
518 axes.append(("axis{}".format(large_axis), [large_axis]))
519 else:
520 for a in range(0, len(shape)):
521 axes.append(("axis{}".format(a), [a]))
522
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 return axes
524
525 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100526 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700527 arg_list = []
528
529 ifm_shape = shapeList[0]
530 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100531 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
532 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700533
Les Bell7aa69f42021-09-20 10:44:07 +0100534 # Check the rank
535 rank = 5 if opName.startswith("conv3d") else 4
536 assert len(ifm_shape) == rank
537 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700538
Les Bell7aa69f42021-09-20 10:44:07 +0100539 # kernel rank omits batch and channels
540 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
Les Bell7aa69f42021-09-20 10:44:07 +0100542 # Generate comprehensive argument lists
543 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
544 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
545 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
546 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
547 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
548 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700549
Les Bell7aa69f42021-09-20 10:44:07 +0100550 # add some oversize argument values
551 if max(ifm_shape) < 64:
552 bigPadding = 9
553 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
554 bigStride = 8
555 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
556 bigDilation = 7
557 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100558
559 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100560 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
561 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
562 if sparsity < 13:
563 sparsity = 1
564 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
565 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100566 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100567 for s in sorted(list(strides)):
568 for p in sorted(list(paddings)):
569 for d in sorted(list(dilations)):
570 if (n % sparsity == 0
571 # padding must not exceed the kernel size ?
572 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
573 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
574 # the padded shape must exceed the kernel size
575 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
576 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
577 # the padded shape must exceed the dilation
578 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
579 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
580 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100581 arg_list.append(
582 (
583 "st{}_pad{}_dilat{}".format(
584 "".join([str(x) for x in s]),
585 "".join([str(x) for x in p]),
586 "".join([str(x) for x in d]),
587 ),
588 [s, p, d],
589 )
590 )
591 n += 1
592
Kevin Cheng1533b852021-09-01 12:51:58 -0700593 return arg_list
594
595 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100596 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 arg_list = []
598
599 ifm_shape = shapeList[0]
600 filter_shape = shapeList[1]
601
602 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800603 assert len(ifm_shape) == 4
604 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Les Bell7aa69f42021-09-20 10:44:07 +0100606 # Generate comprehensive argument lists
607 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
608 paddings = {x for x in itertools.product(*([p_vals] * 2))}
609 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
610 strides = {x for x in itertools.product(*([s_vals] * 2))}
611 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
612 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
Les Bell7aa69f42021-09-20 10:44:07 +0100614 # add some oversize argument values
615 if max(ifm_shape) < 64:
616 bigPadding = 9
617 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
618 bigStride = 8
619 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
620 bigDilation = 7
621 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Les Bell7aa69f42021-09-20 10:44:07 +0100623 # There are too many parameter combinations, so generate them sparsely
624 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
625 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
626 if sparsity < 13:
627 sparsity = 1
628 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
629 sparsity += 1
630 n = 0
631 for s in sorted(list(strides)):
632 for p in sorted(list(paddings)):
633 for d in sorted(list(dilations)):
634 if n % sparsity == 0:
635 # Determine the output shape
636 oh = (
637 ifm_shape[1]
638 - filter_shape[1]
639 - (filter_shape[1] - 1) * (d[0] - 1)
640 + 2 * p[0]
641 ) // s[0] + 1
642 ow = (
643 ifm_shape[2]
644 - filter_shape[2]
645 - (filter_shape[2] - 1) * (d[1] - 1)
646 + 2 * p[1]
647 ) // s[1] + 1
648 os = [ifm_shape[0], oh, ow, filter_shape[0]]
649 arg_list.append(
650 (
651 "st{}_pad{}_dilat{}_os{}".format(
652 "".join([str(x) for x in s]),
653 "".join([str(x) for x in p]),
654 "".join([str(x) for x in d]),
655 "x".join([str(x) for x in os]),
656 ),
657 [s, p, d, os],
658 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 )
Les Bell7aa69f42021-09-20 10:44:07 +0100660 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
662 return arg_list
663
664 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100665 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700666 arg_list = []
667 rank = len(shapeList[0])
668
Les Bell7ffccce2021-07-28 15:37:02 +0100669 # Exhaustively test combinations of padding on each side of each dimension
670 # - the range of padding values is defined by pad_min and pad_max
671 # - for padding >9, the name format needs to be more distinctive
672 pad_min, pad_max = 0, 1
673 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100674 if error_name == ErrorIf.PadSmallerZero:
675 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100676 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
677 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700678
Les Bell7ffccce2021-07-28 15:37:02 +0100679 for paddings in shape_pad_values:
680 name = "pad"
681 for r in range(rank):
682 before, after = paddings[r]
683 name = f"{name}{before}{after}"
684 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700685
686 return arg_list
687
@staticmethod
def agPooling(testGen, opName, shapeList, dtype, error_name=None):
    """Build a sparse set of (stride, pad, kernel) arguments for 2D pooling.

    Candidate strides, kernels (2-tuples) and paddings (4-tuples) come from
    the ``max_pooling_*`` command-line limits plus a few oversize values.
    Only every ``sparsity``-th combination is kept, so the list stays at
    roughly 500 entries or fewer.

    For the pooling ERROR_IF tests (StrideSmallerOne, KernelSmallerOne,
    PadSmallerZero, PadLargerEqualKernel), each combination is first passed
    through TosaErrorIfArgGen.eiPoolingErrorIf. It is dropped when that
    returns None. Otherwise a combination is kept only when the padding is
    smaller than the kernel and the padded input is larger than the kernel.

    Returns a list of (name, [stride, pad, kernel]) tuples.
    """
    shape = shapeList[0]
    if error_name != ErrorIf.WrongRank:
        assert len(shape) == 4

    args = testGen.args

    # Generate comprehensive argument lists
    paddings = set(itertools.product(range(0, args.max_pooling_padding + 1), repeat=4))
    strides = set(itertools.product(range(1, args.max_pooling_stride + 1), repeat=2))
    kernels = set(itertools.product(range(2, args.max_pooling_kernel + 2), repeat=2))

    # add some oversize argument values
    bigStride = 7
    strides.update(itertools.product([1, bigStride], repeat=2))
    bigKernel = 6
    kernels.update(itertools.product([2, bigKernel], repeat=2))
    if max(shape) < 64:
        # padding must be less than the kernel size
        bigPadding = bigKernel - 1
        paddings.update(itertools.product([0, bigPadding], repeat=4))

    def make_arg(s, p, k):
        # Name encodes stride, kernel, then padding digits
        label = "st{}_kern{}_pad{}".format(
            "".join(map(str, s)),
            "".join(map(str, k)),
            "".join(map(str, p)),
        )
        return (label, [s, p, k])

    def fits(s, p, k):
        # Short-circuit order matters: shape is only indexed once the
        # padding checks have passed.
        return (
            # padding must not exceed the kernel size
            p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
            # the padded shape must exceed the kernel size
            and (shape[1] + p[0] + p[1]) > k[0]
            and (shape[2] + p[2] + p[3]) > k[1]
        )

    use_error_if = error_name in [
        ErrorIf.StrideSmallerOne,
        ErrorIf.KernelSmallerOne,
        ErrorIf.PadSmallerZero,
        ErrorIf.PadLargerEqualKernel,
    ]

    # There are too many parameter combinations, so generate them sparsely
    sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
    combos = itertools.product(sorted(strides), sorted(paddings), sorted(kernels))

    arg_list = []
    for n, (s, p, k) in enumerate(combos):
        keep = n % sparsity == 0
        if use_error_if:
            # Called for every combination (not just kept ones) so that the
            # random number stream is consumed identically.
            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
            if None not in [sNew, pNew, kNew] and keep:
                arg_list.append(make_arg(sNew, pNew, kNew))
        elif keep and fits(s, p, k):
            arg_list.append(make_arg(s, p, k))

    return arg_list
752
@staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate the output dtypes CAST is tested with for ``inDtype``.

    Returns a list of ("out<TypeName>", [output_dtype]) tuples.
    Raises Exception for an input dtype with no configured outputs.
    """
    # Enumerate the output types here
    cast_targets = {
        DType.INT8: [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT],
        DType.INT16: [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT],
        DType.INT32: [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT],
        DType.BOOL: [DType.INT8, DType.INT16, DType.INT32],
        DType.FLOAT: [DType.INT8, DType.INT16, DType.INT32],
    }

    if inDtype not in cast_targets:
        raise Exception("Unexpected input dtype: {}".format(inDtype))

    return [
        ("out{}".format(DTypeNames[target]), [target])
        for target in cast_targets[inDtype]
    ]
775
@staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate RESCALE arguments (output dtype, scale32, double_round, per_channel).

    Illegal combinations are filtered out, unless they are exactly what the
    requested ERROR_IF test needs.

    Returns a list of (name, [dtype, scale32, double_round, per_channel]) tuples.
    """

    def output_dtype_wanted(out_dtype):
        # Zero-point errors are only exercised on the wider output types
        if out_dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
            return False
        if error_name == ErrorIf.WrongOutputType:
            # Keep only the genuinely illegal input/output pairings
            return TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, out_dtype)
        if inDtype == DType.UINT8 and out_dtype != DType.INT8:
            # The only output dtype for UINT8 is INT8, skip all other combinations
            return False
        if inDtype != DType.INT8 and out_dtype == DType.UINT8:
            # The only input dtype for UINT8 is INT8, skip all other combinations
            return False
        return True

    def flags_wanted(scale32, double_round):
        if error_name == ErrorIf.ScaleTrue and not scale32:
            return False
        # ScaleNotTrue needs exactly scale32=False with double_round=True
        if error_name == ErrorIf.ScaleNotTrue and (scale32 or not double_round):
            return False
        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
            # Illegal condition. Must be scale32=False
            return False
        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
            # Illegal condition. ERROR_IF(!scale32 && double_round)
            return False
        return True

    arg_list = []
    # Enumerate the output types here
    for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
        if not output_dtype_wanted(dtype):
            continue
        # Same nesting order as scale32 -> double_round -> per_channel loops
        for scale32, double_round, per_channel in itertools.product([False, True], repeat=3):
            if not flags_wanted(scale32, double_round):
                continue
            arg_list.append(
                (
                    "out{}_sc{}_dr{}_pc{}".format(
                        DTypeNames[dtype],
                        int(scale32),
                        int(double_round),
                        int(per_channel),
                    ),
                    [dtype, scale32, double_round, per_channel],
                )
            )

    return arg_list
823
@staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Generate MUL shift arguments.

    INT32 MUL gets ``num_rand_permutations`` random shifts drawn with
    ``testGen.randInt(0, 32)``. Every other dtype gets a single zero shift.

    Returns a list of (name, [shift]) tuples.
    """
    # Equality, not identity: DType members are plain ints, and `is` only
    # worked by accident of CPython's small-int cache.
    if dtype != DType.INT32:
        return [("perm0_shift0", [0])]

    arg_list = []
    for p in range(testGen.args.num_rand_permutations):
        shift = testGen.randInt(0, 32)
        arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))

    return arg_list
838
@staticmethod
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
    """Return the two rounding variants: ("roundTrue", [True]) then ("roundFalse", [False])."""
    return [("round{}".format(rounding), [rounding]) for rounding in (True, False)]
847
# Helper function for reshape. Gets some factors of a larger number.
@staticmethod
def getFactors(val, start=1):
    """Return the divisors of ``val`` in [start, floor(sqrt(val))], ascending.

    Only the "small half" of each divisor pair is returned; the matching
    large divisor is ``val // d``.
    """
    limit = int(np.sqrt(val))
    return [cand for cand in range(start, limit + 1) if val % cand == 0]
858
@staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random, distinct RESHAPE target shapes.

    For each permutation a rank is drawn with ``testGen.randInt(1, 7)``.
    A shape of that rank is then built whose product equals the input
    element count. Each of the first ``rank - 1`` dimensions is a random
    factor (from getFactors) of the elements still to be placed, and the
    last dimension takes the remainder. Occasionally one dimension is
    replaced by -1. A shape that duplicates an earlier one is retried, up to
    100 attempts; if it is still a duplicate the permutation is skipped.

    Returns a list of ("perm<p>_rank<r>", [new_shape]) tuples.
    """
    # Every candidate shape must multiply out to this value
    element_count = 1
    for dim in shapeList[0]:
        element_count *= dim

    # This code is NOT fast. Fortunately, the numbers are fairly small.
    factors = TosaArgGen.getFactors(element_count)

    arg_list = []
    for perm in range(testGen.args.num_rand_permutations):
        target_rank = testGen.randInt(1, 7)
        if target_rank > len(factors):
            continue

        duplicate = True
        attempts = 0
        while duplicate:
            candidate = []
            remaining = element_count
            pool = testGen.rng.permutation(factors)
            # pick rank-1 factors; the remainder becomes the last dimension
            for _ in range(target_rank - 1):
                chosen = pool[0]
                candidate.append(chosen)
                remaining = remaining // chosen
                pool = testGen.rng.permutation(TosaArgGen.getFactors(remaining))
            candidate.append(remaining)

            # Toss in a -1 sometimes
            wildcard_pos = testGen.randInt(0, target_rank * 4)
            if wildcard_pos < target_rank:
                candidate[wildcard_pos] = -1

            # Check for duplicates
            duplicate = any(existing[0] == candidate for _, existing in arg_list)

            # give up on this permutation rather than loop forever
            attempts += 1
            if attempts >= 100:
                break

        if not duplicate:
            arg_list.append(("perm{}_rank{}".format(perm, target_rank), [candidate]))

    return arg_list
914
@staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
    """Generate TRANSPOSE permutation arguments.

    By default every permutation of the input axes is a candidate. For the
    IndexOutsideBounds ERROR_IF, candidates are permutations of indices that
    are all >= rank or all negative. For IndexUsedTwice, they are permutations
    of an axis list in which one randomly chosen index appears twice.
    Candidates are shuffled and at most ``num_rand_permutations`` are returned.

    Returns a list of ("perm<p>", [perm_as_list]) tuples.
    """
    rank = len(shapeList[0])

    if error_name == ErrorIf.IndexOutsideBounds:
        too_large = range(rank + 1, 2 * rank + 1)
        negative = range(-rank, 0)
        candidates = list(itertools.permutations(too_large))
        candidates += list(itertools.permutations(negative))
    elif error_name == ErrorIf.IndexUsedTwice:
        # Create list with a duplicated index
        axes = list(range(rank))
        index_choice = testGen.rng.choice(range(len(axes)))
        axes[(index_choice + 1) % len(axes)] = axes[index_choice]
        candidates = list(itertools.permutations(axes))
    else:
        # Get all permutations
        candidates = list(itertools.permutations(range(rank)))

    # Limit to possible permutations from shape dimension or argument setting
    count = min(len(candidates), testGen.args.num_rand_permutations)

    # Get random permutation generator that uses all permutations
    shuffled = testGen.rng.permutation(candidates)

    # Create list of required amount of permutations
    return [("perm{}".format(i), [shuffled[i].tolist()]) for i in range(count)]
951
@staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random SLICE (start, size) arguments.

    For every dimension larger than 1, start and size are drawn with
    ``testGen.randInt``; dimensions of size <= 1 use start 0 and size 1.
    A permutation that draws any zero size is discarded. Kept arguments are
    passed through TosaErrorIfArgGen.eiSliceErrorIf, which may corrupt them
    for ERROR_IF tests.

    Returns a list of ("perm<p>", [start, size]) tuples.
    """
    ifm_shape = shapeList[0]
    arg_list = []

    for perm in range(testGen.args.num_rand_permutations):
        start = []
        size = []
        nonzero_sizes = True

        for dim in ifm_shape:
            if dim <= 1:
                start.append(0)
                size.append(1)
                continue
            begin = testGen.randInt(0, dim)
            length = testGen.randInt(0, dim - begin)
            start.append(begin)
            size.append(length)
            # Invalid slice size?
            if length == 0:
                nonzero_sizes = False

        if not nonzero_sizes:
            continue

        # If ERROR_IF test required then incorrect start, size will be returned
        start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
        arg_list.append(("perm{}".format(perm), [start, size]))

    return arg_list
982
@staticmethod
def agTile(testGen, opName, shapeList, dtype, error_name=None):
    """Generate TILE ``multiples`` arguments, one list per permutation.

    Multiples are kept small because otherwise this has a tendency to
    generate enormous tensors:
    - a dimension larger than 1000 gets multiple 1;
    - any other dimension gets 2 if some dimension exceeds 1000;
    - otherwise the multiple is ``testGen.randInt(1, 4)``.

    Returns a list of ("perm<p>", [multiples]) tuples.
    """
    ifm_shape = shapeList[0]
    # Equivalent to max(ifm_shape) > 1000, but also well-defined for rank 0
    has_huge_dim = any(dim > 1000 for dim in ifm_shape)

    arg_list = []
    for perm in range(testGen.args.num_rand_permutations):
        multiples = []
        for dim in ifm_shape:
            if dim > 1000:
                # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
                multiples.append(1)
            elif has_huge_dim:
                multiples.append(2)
            else:
                multiples.append(testGen.randInt(1, 4))
        arg_list.append(("perm{}".format(perm), [multiples]))

    return arg_list
1007
@staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Generate RESIZE arguments for NEAREST and BILINEAR modes.

    For each legal (mode, output dtype) pair and each permutation, output
    dimensions are drawn randomly. Each is then increased until the input is
    less than 16x the output size along that axis. Stride and offset are
    derived so that the output grid is centred on the input grid.

    - FLOAT outputs use floating-point stride/offset with shift 0.
    - Integer outputs use fixed-point stride/offset, starting at shift 11 and
      lowering the shift until the values fit the range checked below.

    When ``error_name`` is set, TosaErrorIfArgGen.eiResizeErrorIf corrupts
    the arguments (and possibly the output dtype).

    Returns a list of (name, [mode, stride, offset, shift, stride_fp,
    offset_fp, output_dims, dtype, output_dtype]) tuples.
    """
    ifm_shape = shapeList[0]

    def legal_output_dtypes(mode):
        # Exclude illegal {mode, type} configurations. Pick legal output types
        if mode == ResizeMode.NEAREST and dtype == DType.INT8:
            return [DType.INT8]
        if mode == ResizeMode.NEAREST and dtype == DType.INT16:
            return [DType.INT16]
        if mode == ResizeMode.BILINEAR and dtype == DType.INT8:
            return [DType.INT32]
        if mode == ResizeMode.BILINEAR and dtype == DType.INT16:
            return [DType.INT48]
        if dtype == DType.FLOAT:
            return [DType.FLOAT]
        if error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            return [DType.INT8, DType.INT16, DType.INT32]
        return []

    def to_fixed_point(fp_stride_y, fp_stride_x, fp_offset_y, fp_offset_x):
        # Lower the shift until stride < (16 << shift) and offset lies
        # strictly between -(16 << shift) and (16 << shift).
        shift = 11
        while True:
            unit = float(1 << shift)
            stride = [int(round(fp_stride_y * unit)), int(round(fp_stride_x * unit))]
            offset = [int(round(fp_offset_y * unit)), int(round(fp_offset_x * unit))]
            bound = 16 << shift
            if max(stride) < bound and all(-bound < o < bound for o in offset):
                return shift, stride, offset
            shift -= 1

    arg_list = []
    for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
        mode_tag = "N" if mode == ResizeMode.NEAREST else "B"

        for outputDType in legal_output_dtypes(mode):
            for perm in range(testGen.args.num_rand_permutations):
                # Randomly generate legal output dimensions and shift
                # and then compute the stride and offset based on them
                # A output_dim of 1 will cause offset to exceed allowed range
                # so minimum value 2 produced below
                output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
                    output_dims[0] += 1
                while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
                    output_dims[1] += 1

                in_center_h = (ifm_shape[1] - 1) / 2.0
                in_center_w = (ifm_shape[2] - 1) / 2.0
                out_center_h = (output_dims[0] - 1) / 2.0
                out_center_w = (output_dims[1] - 1) / 2.0

                fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                fp_offset_y = in_center_h - fp_stride_y * out_center_h
                fp_offset_x = in_center_w - fp_stride_x * out_center_w

                is_float = outputDType == DType.FLOAT
                if is_float:
                    shift = 0
                    stride = [0, 0]
                    offset = [0, 0]
                    stride_fp = [fp_stride_y, fp_stride_x]
                    offset_fp = [fp_offset_y, fp_offset_x]
                else:
                    shift, stride, offset = to_fixed_point(
                        fp_stride_y, fp_stride_x, fp_offset_y, fp_offset_x
                    )
                    stride_fp = [0.0, 0.0]
                    offset_fp = [0.0, 0.0]

                if error_name is None:
                    outputDTypeNew = outputDType
                else:
                    shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                        testGen,
                        error_name,
                        mode,
                        dtype,
                        shapeList,
                        outputDType,
                        shift,
                        stride,
                        stride_fp,
                        offset,
                        offset_fp
                    )

                if is_float:
                    name = "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                        mode_tag,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride_fp[0],
                        stride_fp[1],
                        offset_fp[0],
                        offset_fp[1],
                    )
                else:
                    name = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                        mode_tag,
                        shift,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride[0],
                        stride[1],
                        offset[0],
                        offset[1],
                    )

                arg_list.append(
                    (
                        name,
                        [
                            mode,
                            stride,
                            offset,
                            shift,
                            stride_fp,
                            offset_fp,
                            output_dims,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                )

    return arg_list
1178
@staticmethod
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Generate the COND_IF condition values: ("cond0", [False]) then ("cond1", [True]).

    The conditions are converted to tensors in the build function, along
    with the then and else blocks.
    """
    # Declared @staticmethod like every other arg generator, so it also
    # works when looked up through an instance.
    return [("cond{}".format(int(cond)), [cond]) for cond in (False, True)]
1189
@staticmethod
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Generate WHILE_LOOP iteration counts: 0 iterations, 1, and more than 1 (4).

    Returns [("iter0", [0]), ("iter1", [1]), ("iter4", [4])].
    """
    # Declared @staticmethod for consistency with the other arg generators;
    # the loop variable no longer shadows the builtin `iter`.
    return [("iter{}".format(num_iters), [num_iters]) for num_iters in (0, 1, 4)]
1198
class TosaErrorIfArgGen:
    """Helpers that turn legal operator arguments into ERROR_IF-triggering ones.

    Each helper receives the requested ``error_name`` plus otherwise-legal
    arguments and returns values chosen to hit that specific ERROR_IF check.
    """

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Corrupt RESIZE arguments so that ``error_name`` is triggered.

        Returns (shift, stride, stride_fp, offset, offset_fp, outputDType).
        Values unrelated to ``error_name`` are returned unchanged. For
        StrideLargerDimension, ``stride`` / ``stride_fp`` are modified in place.
        """

        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                # NOTE(review): with numpy Generator.integers the high bound is
                # exclusive, so shift is in [-3, 0] and the first branch is taken.
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Pick any output type other than the single legal one for this {mode, dtype}
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Corrupt pooling (stride, pad, kernel) so that ``error_name`` is triggered.

        Returns the modified (stride, pad, kernel) triple. Returns
        (None, None, None) when ``error_name`` is not one of the pooling
        errors handled here. For StrideSmallerOne it also returns
        (None, None, None) when ``pad`` is not already smaller than
        ``kernel``, so that only the stride error is hit.
        """
        if (error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if ``output_dtype`` is NOT a legal RESCALE output for ``input_dtype``.

        Input dtypes not listed here always return False.
        """
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up the input/output tensor lists for the list-length ERROR_IF checks.

        WrongInputList: randomly appends a dummy input (in place) or drops
        the last input (as a new list). WrongOutputList: randomly appends a
        dummy output (in place) or replaces the outputs with an empty list.
        Returns (input_list, output_list).
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimension(shape, error_name):
        """Keep WrongRank test tensors small.

        For WrongRank with rank > 4, dimension 4 is set to 1 (``shape`` is
        modified in place). Returns ``shape``.
        """
        # Restrict dimension size if rank is large for WrongRank Error_If
        # This will keep the test sizes reasonably small
        if error_name == ErrorIf.WrongRank:
            if len(shape) > 4:
                shape[4] = 1

        return shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Corrupt SLICE (start, size) so that ``error_name`` is triggered.

        Returns a (start, size) pair; both are returned unchanged when
        ``error_name`` is not a SLICE error handled here. The caller's
        ``start`` / ``size`` lists are never modified.
        """
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            # Start at the last element and ask for more than one element
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i]-1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            # Make start/size one element shorter or longer than the input rank
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                # Build new lists instead of appending to the caller's lists
                newStart = list(start) + [1]
                newSize = list(size) + [1]
            return newStart, newSize
        else:
            return start, size
1362
Matthew Haddone86fd342021-09-07 16:12:21 +01001363class TosaErrorValidator:
1364
@staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
    """Run every ERROR_IF validator for a built test and record the outcome.

    Each validator is called as ``val_fcn(True, **kwargs)`` and must return a
    dict with 'error_name', 'error_result' and 'error_reason'. When the
    validator named ``error_name`` reports a hit, the serializer is told to
    expect return code 2 with that validator's reason.

    Returns:
        True if no validator contradicts ``error_name``.
        None if a different ERROR_IF was hit, or the requested one was not
        hit; this signals that the test should be deleted.
    """
    # Check ERROR_IF statements
    for val_fcn in validator_fcns:
        val_result = val_fcn(True, **kwargs)

        validator_name = val_result['error_name']
        error_result = val_result['error_result']
        error_reason = val_result['error_reason']

        if error_result:
            if error_name == validator_name:
                serializer.setExpectedReturnCode(2, error_reason)
            else:
                print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
                return None # Return None to delete test if wrong ERROR_IF is hit
        elif error_name == validator_name:
            print(f"No ERROR_IF hit for {error_name}")
            return None

    # Previously this fell off the end and returned None implicitly, which
    # made success indistinguishable from the "delete this test" signal.
    return True
1386
@staticmethod
def evWrongInputType(check=False, **kwargs):
    """Describe, and optionally check, the WrongInputType ERROR_IF.

    ``kwargs['op']`` is an operator dict whose 'types' entry lists the
    supported input dtypes. Each entry is either a dtype or a list whose first
    element is the input dtype. When ``check`` is true,
    ``kwargs['input_dtype']`` is tested against that set.

    Returns a dict with 'error_name', 'error_result', 'error_reason' and
    'param_reqs'. param_reqs['dtype'] lists the unsupported input dtypes.
    """
    all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}

    # Find the unsupported input data types
    assert 'op' in kwargs
    op = kwargs['op']
    input_dtypes = op['types']

    # Flatten list-shaped entries down to their input dtype
    allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
    wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)

    error_name = ErrorIf.WrongInputType
    param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
    error_result = False
    error_reason = "Input data type not supported for this operator"

    if check:
        input_dtype = kwargs['input_dtype']
        # Check against the flattened set for every operator. Testing against
        # the raw 'types' list never matches a scalar dtype when the entries
        # are lists, which previously only FULLY_CONNECTED special-cased.
        if input_dtype not in allowed_input_dtypes:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1419
@staticmethod
def evWrongOutputType(check=False, **kwargs):
    """ERROR_IF validator for an output data type illegal for the given input.

    With check=True, reads ``input_dtype``, ``output_dtype`` and ``op`` (plus
    ``mode`` for RESIZE) from kwargs. RESIZE, RESCALE, FULLY_CONNECTED,
    MATMUL and ARGMAX have dedicated input->output type rules; every other
    operator must produce the same type it consumes.
    """
    error_name = ErrorIf.WrongOutputType
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Output data type not supported for this configuration of operator"

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        op = kwargs['op']

        if op['op'] == Op.RESIZE:
            mode = kwargs['mode']
            if (
                (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
            ):
                error_result = True
        elif op['op'] == Op.RESCALE:
            # Legal RESCALE output types per input type; input types that
            # are not listed here are not checked by this validator
            rescale_outputs = {
                DType.INT8: [DType.UINT8, DType.INT8, DType.INT16, DType.INT32],
                DType.INT16: [DType.INT8, DType.INT16, DType.INT32],
                DType.INT32: [DType.INT8, DType.INT16, DType.INT32],
                DType.INT48: [DType.INT8, DType.INT16, DType.INT32],
                DType.UINT8: [DType.INT8],
            }
            if input_dtype in rescale_outputs and output_dtype not in rescale_outputs[input_dtype]:
                error_result = True
        elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
            if (
                (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
            ):
                error_result = True
        elif op['op'] == Op.ARGMAX:
            # ARGMAX returns indices, so the output is INT32
            if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
                error_result = True
        elif output_dtype != input_dtype:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1476
@staticmethod
def evWrongRank(check=False, **kwargs):
    """ERROR_IF validator for an input rank the operator does not support.

    The generator requirements list the ranks 1-5 that fall outside
    ``op['rank']`` and are greater than its minimum; RESIZE always uses
    ranks 3 and 5. With check=True, the rank of ``kwargs['input_shape']``
    is validated: RESIZE/AVG_POOL2D/MAX_POOL2D must be rank 4,
    FULLY_CONNECTED rank 2, MATMUL rank 3, and any rank outside the op's
    declared range is an error.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    low, high = op['rank']
    rank_range = range(low, high + 1)

    # Unsupported ranks; those not above the minimum are dropped to avoid
    # index errors while building the tests
    incorrect_ranks = [r for r in set((1, 2, 3, 4, 5)) - set(rank_range) if r > low]
    if op['op'] == Op.RESIZE:
        # Keep the smallest incorrect rank at 3 to avoid index errors
        incorrect_ranks = [3, 5]

    error_result = False
    if check:
        input_rank = len(kwargs['input_shape'])
        # Operators that accept exactly one input rank
        required_ranks = {
            Op.RESIZE: 4,
            Op.AVG_POOL2D: 4,
            Op.MAX_POOL2D: 4,
            Op.FULLY_CONNECTED: 2,
            Op.MATMUL: 3,
        }
        required = required_ranks.get(op['op'])
        if required is not None and input_rank != required:
            error_result = True
        elif input_rank not in rank_range:
            error_result = True

    return {
        "error_name": ErrorIf.WrongRank,
        "error_result": error_result,
        "error_reason": "Rank not supported for this operator",
        "param_reqs": {"rank": incorrect_ranks, "dtype": None, "shape": None},
    }
1518
@staticmethod
def evWrongInputList(check=False, **kwargs):
    """ERROR_IF validator for an input list of unexpected length.

    With check=True, ``len(kwargs['input_list'])`` must equal
    ``kwargs['num_operands']``. PAD and TRANSPOSE are expected to carry one
    extra entry because their build functions add a const layer.
    """
    error_name = ErrorIf.WrongInputList
    mismatch = False

    if check:
        op = kwargs['op']
        input_list = kwargs['input_list']
        num_operands = kwargs['num_operands']
        # both PAD, TRANSPOSE add an extra const layer in the build function
        extra = 1 if op['op'] in [Op.PAD, Op.TRANSPOSE] else 0
        if len(input_list) != num_operands + extra:
            mismatch = True

    return {
        "error_name": error_name,
        "error_result": mismatch,
        "error_reason": "Op input list does not match expected input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1545
@staticmethod
def evWrongOutputList(check=False, **kwargs):
    """ERROR_IF validator for an output list that is not a single tensor.

    With check=True, ``kwargs['output_list']`` must contain exactly one
    entry. This will be incorrect for an operator with several outputs.
    """
    error_name = ErrorIf.WrongOutputList
    mismatch = False
    if check and len(kwargs['output_list']) != 1:
        mismatch = True

    return {
        "error_name": error_name,
        "error_result": mismatch,
        "error_reason": "Op output list does not match expected output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
Matthew Haddone86fd342021-09-07 16:12:21 +01001566
@staticmethod
def evMaxDimExceeded(check=False, **kwargs):
    """ERROR_IF validator for an H/W dimension above 16384.

    With check=True, indices 1 and 2 of ``kwargs['input_shape']`` and
    indices 0 and 1 of ``kwargs['output_shape']`` (which holds only
    (OH, OW)) are compared against 16384.
    """
    requirements = {
        "rank": [4,4],
        "dtype": [DType.INT8],
        "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
    }
    exceeded = False

    if check:
        in_shape = kwargs['input_shape']
        out_hw = kwargs['output_shape']  # Note this is just (OH, OW)
        too_big = (
            in_shape[1] > 16384
            or in_shape[2] > 16384
            or out_hw[0] > 16384
            or out_hw[1] > 16384
        )
        if too_big:
            exceeded = True

    return {
        "error_name": ErrorIf.MaxDimExceeded,
        "error_result": exceeded,
        "error_reason": "At least one maximum dimension is larger than 16384",
        "param_reqs": requirements,
    }
1594
@staticmethod
def evBatchMismatch(check=False, **kwargs):
    """ERROR_IF validator for differing input/output batch sizes.

    With check=True, and only when the input rank lies within
    ``op['rank']``, element 0 of ``kwargs['input_shape']`` is compared with
    element 0 of ``kwargs['result_tensor'].shape``.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    lo_rank, hi_rank = op['rank']
    valid_ranks = range(lo_rank, hi_rank + 1)

    mismatch = False
    if check:
        input_shape = kwargs['input_shape']
        # presumably (N, OH, OW, C) - only index 0 is used here
        result_shape = kwargs['result_tensor'].shape
        rank_ok = len(input_shape) in valid_ranks
        if rank_ok and input_shape[0] != result_shape[0]:
            mismatch = True

    return {
        "error_name": ErrorIf.BatchMismatch,
        "error_result": mismatch,
        "error_reason": "Input batch size not equal to output batch size",
        "param_reqs": {"rank": [4,4], "dtype": None, "shape": None},
    }
1621
@staticmethod
def evChannelMismatch(check=False, **kwargs):
    """ERROR_IF validator for differing input/output channel sizes.

    With check=True, and only when the input rank lies within
    ``op['rank']``, element 3 of ``kwargs['input_shape']`` is compared with
    element 3 of ``kwargs['result_tensor'].shape``.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    lo_rank, hi_rank = op['rank']
    valid_ranks = range(lo_rank, hi_rank + 1)

    mismatch = False
    if check:
        input_shape = kwargs['input_shape']
        # presumably (N, OH, OW, C) - only index 3 is used here
        result_shape = kwargs['result_tensor'].shape
        rank_ok = len(input_shape) in valid_ranks
        if rank_ok and input_shape[3] != result_shape[3]:
            mismatch = True

    return {
        "error_name": ErrorIf.ChannelMismatch,
        "error_result": mismatch,
        "error_reason": "Input channel size not equal to output channel size",
        "param_reqs": {"rank": [4,4], "dtype": None, "shape": None},
    }
1647
@staticmethod
def evStrideSmallerEqualZero(check=False, **kwargs):
    """ERROR_IF validator for a stride component that is zero or negative.

    With check=True, the stride is taken from ``kwargs['stride_fp']`` when
    ``kwargs['input_dtype']`` is FLOAT and from ``kwargs['stride']``
    otherwise; an error is flagged when any component is <= 0.
    ``output_dtype`` may still be passed but does not affect the choice.
    """
    error_name = ErrorIf.StrideSmallerEqualZero
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Stride value smaller than or equal zero"

    if check:
        input_dtype = kwargs['input_dtype']
        # The stride kwarg is selected by input type alone; this also covers
        # the wrong input/output type tests where the two types disagree
        if input_dtype == DType.FLOAT:
            stride = kwargs['stride_fp']
        else:
            stride = kwargs['stride']

        if min(stride) <= 0:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1677
@staticmethod
def evStrideLargerEqualMax(check=False, **kwargs):
    """ERROR_IF validator for an integer stride at or above 16 scaled by shift.

    With check=True and an INT8/INT16 input, the limit is ``16 << shift``
    for non-negative shifts and ``16 >> -shift`` for negative ones; an error
    is flagged if either stride component reaches it.
    """
    too_large = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride']
        if input_dtype in [DType.INT8, DType.INT16]:
            limit = (16 << shift) if shift >= 0 else (16 >> -shift)
            if stride[0] >= limit or stride[1] >= limit:
                too_large = True

    return {
        "error_name": ErrorIf.StrideLargerEqualMax,
        "error_result": too_large,
        "error_reason": "Stride value larger than or equal to maximum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1702
1703
@staticmethod
def evStrideLargerDimension(check=False, **kwargs):
    """ERROR_IF validator for a float stride larger than the H/W dimension.

    With check=True and a FLOAT input, ``kwargs['stride_fp']`` is compared
    with indices 1 and 2 of ``kwargs['input_shape']``. Non-FLOAT inputs are
    never flagged.
    """
    error_name = ErrorIf.StrideLargerDimension
    param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    error_result = False
    error_reason = "Stride value larger than or equal to H/W dimension"

    if check:
        shape = kwargs['input_shape']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride_fp']

        # Both dimension comparisons are grouped under the FLOAT test;
        # without the parentheses `A and B or C` would flag non-FLOAT inputs
        # whenever stride[1] > shape[2].
        if input_dtype == DType.FLOAT and (stride[0] > shape[1] or stride[1] > shape[2]):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1726
1727
@staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
    """ERROR_IF validator for an offset at or below -16 scaled by shift.

    With check=True, the offset comes from ``kwargs['offset_fp']`` for a
    FLOAT output and ``kwargs['offset']`` otherwise. The limit is
    ``-16 << shift`` for non-negative shifts and ``-16 >> -shift`` for
    negative ones; either component reaching it is an error.
    """
    too_small = False

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        offset = kwargs['offset_fp'] if output_dtype == DType.FLOAT else kwargs['offset']

        limit = (-16 << shift) if shift >= 0 else (-16 >> -shift)
        if offset[0] <= limit or offset[1] <= limit:
            too_small = True

    return {
        "error_name": ErrorIf.OffsetSmallerEqualMin,
        "error_result": too_small,
        "error_reason": "Offset value smaller than or equal to minimum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1755
@staticmethod
def evOffsetLargerEqualMax(check=False, **kwargs):
    """ERROR_IF validator for an offset at or above 16 scaled by shift.

    With check=True, the offset comes from ``kwargs['offset_fp']`` for a
    FLOAT output and ``kwargs['offset']`` otherwise. The limit is
    ``16 << shift`` for non-negative shifts and ``16 >> -shift`` for
    negative ones; either component reaching it is an error.
    """
    error_name = ErrorIf.OffsetLargerEqualMax
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Offset value larger than or equal to maximum value"

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        if output_dtype == DType.FLOAT:
            offset = kwargs['offset_fp']
        else:
            offset = kwargs['offset']

        # Single check per shift sign (the shift >= 0 case used to be
        # evaluated twice with identical conditions)
        if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
            error_result = True
        elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1787
@staticmethod
def evShiftNotZero(check=False, **kwargs):
    """ERROR_IF validator for a non-zero shift on a float-to-float resize.

    With check=True, flags ``kwargs['shift'] != 0`` when both input and
    output types are FLOAT.
    """
    bad_shift = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        float_path = input_dtype == DType.FLOAT and output_dtype == DType.FLOAT
        if float_path and shift != 0:
            bad_shift = True

    return {
        "error_name": ErrorIf.ShiftNotZero,
        "error_result": bad_shift,
        "error_reason": "Shift value must be zero for float input",
        "param_reqs": {"rank": None, "dtype": [DType.FLOAT], "shape": None},
    }
1809
1810
@staticmethod
def evShiftSmallerOne(check=False, **kwargs):
    """ERROR_IF validator for a shift below one on a non-float configuration.

    With check=True, flags ``kwargs['shift'] < 1`` when neither input nor
    output type is FLOAT.
    """
    bad_shift = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        integer_path = input_dtype != DType.FLOAT and output_dtype != DType.FLOAT
        if integer_path and shift < 1:
            bad_shift = True

    return {
        "error_name": ErrorIf.ShiftSmallerOne,
        "error_result": bad_shift,
        "error_reason": "Shift value smaller than one",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1832
@staticmethod
def evShiftLargerEleven(check=False, **kwargs):
    """ERROR_IF validator for a shift greater than eleven.

    With check=True, flags ``kwargs['shift'] > 11`` regardless of dtype.
    """
    bad_shift = bool(check) and kwargs['shift'] > 11

    return {
        "error_name": ErrorIf.ShiftLargerEleven,
        "error_result": True if bad_shift else False,
        "error_reason": "Shift value larger than eleven",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    }
1852
1853
@staticmethod
def evRankMismatch(check=False, **kwargs):
    """ERROR_IF validator for input ranks that differ from the output rank.

    With check=True, the ranks of ``kwargs['input1']`` and
    ``kwargs['input2']`` must both equal that of ``kwargs['result_tensor']``.
    """
    mismatch = False

    if check:
        rank_in1 = len(kwargs['input1'].shape)
        rank_in2 = len(kwargs['input2'].shape)
        rank_out = len(kwargs['result_tensor'].shape)
        if rank_in1 != rank_out or rank_in2 != rank_out:
            mismatch = True

    return {
        "error_name": ErrorIf.RankMismatch,
        "error_result": mismatch,
        "error_reason": "Input Rank does not match output rank",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1875
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
    """ERROR_IF validator for a non-zero input zero point on a non-INT8 input.

    The generator requirements are the entries of ``op['types']`` whose input
    type is neither INT8 nor UINT8. Entries may be plain dtypes or lists
    whose first element is the input type (e.g. FULLY_CONNECTED's
    combinations).

    With check=True the input zero point is read from ``kwargs['qinfo']``:
    either a tuple whose element 0 is the input zp, or an object whose
    ``ints[0][1]`` holds it. MATMUL checks both inputs, using
    ``ints[1][1]`` and ``kwargs['input2_dtype']`` for the second one.
    """
    op = kwargs['op']
    # Input types for which a non-zero zero point is legal
    zero_point_dtypes = [DType.INT8, DType.UINT8]

    # op['types'] is always a list (it supports .copy()), so filter each
    # entry on its input type rather than slicing off the first two
    # entries; slicing dropped legal candidates (e.g. INT16) from plain
    # dtype lists.
    inputDtypes = [
        t for t in op['types']
        if (t[0] if isinstance(t, list) else t) not in zero_point_dtypes
    ]

    error_name = ErrorIf.InputZeroPointNotZero
    param_reqs = {
        "rank": None,
        "dtype": inputDtypes,
        "shape": None
    }
    error_result = False
    error_reason = "Input DType not INT8 and zero point not 0"

    if check:
        input_dtype = kwargs['input_dtype']
        if op['op'] == Op.MATMUL:
            # qinfo.ints[0][1] = input1 zp, qinfo.ints[1][1] = input2 zp
            input2_dtype = kwargs['input2_dtype']
            qinfo = kwargs['qinfo'].ints
            input1_zero_point = qinfo[0][1]
            input2_zero_point = qinfo[1][1]
            if (input_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
                error_result = True
        else:
            if isinstance(kwargs['qinfo'], tuple):
                input_zero_point = kwargs['qinfo'][0]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                input_zero_point = kwargs['qinfo'].ints[0][1]
            if input_dtype not in zero_point_dtypes and input_zero_point != 0:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1927
1928
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
    """ERROR_IF validator for a non-zero weight zero point on non-INT8 weights.

    The generator requirements keep every plain dtype entry of
    ``op['types']`` and every list entry whose element 1 (the weight type)
    is not INT8. With check=True, ``kwargs['qinfo'].ints[1][1]`` is the
    weight zero point and must be 0 unless ``kwargs['weight_dtype']`` is
    INT8.
    """
    op = kwargs['op']

    # exclude inputs with INT8 weights
    candidate_dtypes = []
    for entry in op['types']:
        if isinstance(entry, list) and entry[1] == DType.INT8:
            continue
        candidate_dtypes.append(entry)

    nonzero = False
    if check:
        weight_dtype = kwargs['weight_dtype']
        # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
        weight_zp = kwargs['qinfo'].ints[1][1]
        if weight_dtype != DType.INT8 and weight_zp != 0:
            nonzero = True

    return {
        "error_name": ErrorIf.WeightZeroPointNotZero,
        "error_result": nonzero,
        "error_reason": "Weight DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": candidate_dtypes, "shape": None},
    }
1961
1962
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
    """ERROR_IF validator for a non-zero output zero point on a non-INT8 type.

    The generator requirements are ``op['types']`` with its first INT8 and
    first UINT8 entries removed. With check=True the output zero point is
    read from ``kwargs['qinfo']`` (tuple element 1, or ``ints[1][1]``). For
    AVG_POOL2D it must be 0 unless the input type is INT8; for other
    operators it must be 0 unless the output type is INT8 or UINT8.
    """
    op = kwargs['op']
    candidate_dtypes = op['types'].copy()
    for quantized in (DType.INT8, DType.UINT8):
        if quantized in candidate_dtypes:
            candidate_dtypes.remove(quantized)

    nonzero = False
    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        qinfo = kwargs['qinfo']
        if isinstance(qinfo, tuple):
            output_zp = qinfo[1]
        else:
            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            output_zp = qinfo.ints[1][1]

        if op['op'] == Op.AVG_POOL2D:
            illegal = input_dtype != DType.INT8 and output_zp != 0
        else:
            illegal = output_dtype not in [DType.INT8, DType.UINT8] and output_zp != 0
        if illegal:
            nonzero = True

    return {
        "error_name": ErrorIf.OutputZeroPointNotZero,
        "error_result": nonzero,
        "error_reason": "Output DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": candidate_dtypes, "shape": None},
    }
2004
@staticmethod
def evAxisSmallerZero(check=False, **kwargs):
    """ERROR_IF validator for a negative axis.

    With check=True, flags ``kwargs['axis'] < 0``.
    """
    negative = False
    if check and kwargs['axis'] < 0:
        negative = True

    return {
        "error_name": ErrorIf.AxisSmallerZero,
        "error_result": negative,
        "error_reason": "Axis smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2024
2025
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
    """ERROR_IF validator for an axis beyond the last input dimension.

    With check=True, ``kwargs['axis']`` is out of range when it is not below
    the rank of ``kwargs['input_shape']``; valid axes are 0..rank-1.
    """
    error_name = ErrorIf.AxisLargerRank
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Axis larger than rank"

    if check:
        axis = kwargs['axis']
        shape = kwargs['input_shape']
        # axis == rank already indexes past the last dimension, so use >=
        if axis >= len(shape):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2046
2047
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
    """ERROR_IF validator for an output dimension at ``axis`` that is not 1.

    With check=True and an axis inside ``kwargs['output_shape']``, flags
    ``output_shape[axis] != 1``. Out-of-range axes are not flagged here.
    """
    not_one = False

    if check:
        axis = kwargs['axis']
        out_shape = kwargs['output_shape']
        in_range = 0 <= axis < len(out_shape)
        if in_range and out_shape[axis] != 1:
            not_one = True

    return {
        "error_name": ErrorIf.ShapeOfAxisNotOne,
        "error_result": not_one,
        "error_reason": "shape[axis] is not equal to 1",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2068
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002069
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
    """ERROR_IF validator for a negative padding value.

    With check=True: for PAD, ``kwargs['pad']`` is treated as a sequence of
    per-dimension padding sequences and any negative value is an error; for
    other operators ``kwargs['pad']`` is a flat sequence.
    """
    negative = False

    if check:
        op = kwargs['op']
        pad = kwargs['pad']
        if op['op'] == Op.PAD:
            offending = [padding for padding in pad if min(padding) < 0]
            if offending:
                negative = True
        elif min(pad) < 0:
            negative = True

    return {
        "error_name": ErrorIf.PadSmallerZero,
        "error_result": negative,
        "error_reason": "At least one pad is smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2095
2096
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
    """ERROR_IF validator for a pooling pad at least as large as the kernel.

    With check=True, ``kwargs['pad']`` is (top, bottom, left, right) and
    ``kwargs['kernel']`` is (y, x). Top/bottom are compared with kernel y
    and left/right with kernel x. The check only runs when pads are
    non-negative and kernels are at least 1; other ERROR_IF checks cover
    those cases.
    """
    error_name = ErrorIf.PadLargerEqualKernel
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one pad is larger than kernel dimension"

    if check:
        pad = kwargs['pad']
        kernel = kwargs['kernel']
        # Same validity guard as evPoolingOutputShapeMismatch: a zero pad or a
        # kernel of 1 must not hide another pad that is >= its kernel
        if min(pad) >= 0 and min(kernel) >= 1:
            if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2118
@staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
    """ERROR_IF validator for a pooling output H/W that does not match the params.

    With check=True, and only when kernel >= 1, stride >= 1, pad >= 0 and
    every pad is below its kernel dimension, the expected output is
    ``(I + pad_a + pad_b + stride - kernel) // stride`` per axis. It is
    compared with indices 1 and 2 of ``kwargs['output_shape']``.
    """
    mismatch = False

    if check:
        pad = kwargs['pad']
        pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

        kernel = kwargs['kernel']
        kernel_y, kernel_x = kernel[0], kernel[1]

        input_shape = kwargs['input_shape']
        IH, IW = input_shape[1], input_shape[2]

        output_shape = kwargs['output_shape']
        OH, OW = output_shape[1], output_shape[2]

        stride = kwargs['stride']
        stride_y, stride_x = stride[0], stride[1]

        pads_fit_kernel = not (
            pad_top >= kernel_y or pad_bottom >= kernel_y
            or pad_left >= kernel_x or pad_right >= kernel_x
        )
        params_valid = min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0 and pads_fit_kernel

        # Strides are >= 1 here, so the divisions are safe
        if params_valid:
            expected_h = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
            expected_w = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x
            if OH != expected_h or OW != expected_w:
                mismatch = True

    return {
        "error_name": ErrorIf.PoolingOutputShapeMismatch,
        "error_result": mismatch,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2161
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
    """ERROR_IF validator for an ARGMAX output shape that is not input minus axis.

    With check=True, ``kwargs['output_shape']`` must equal
    ``kwargs['input_shape']`` with dimension ``kwargs['axis']`` removed.
    Dimensions are only compared when the axis is within [0, rank) and the
    output has exactly one dimension fewer than the input; other ERROR_IF
    checks handle bad axes and ranks.
    """
    error_name = ErrorIf.ArgmaxOutputShapeMismatch
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Mismatch between output shape provided and expected output shape"

    if check:
        output_shape = kwargs['output_shape']
        input_shape = kwargs['input_shape']
        axis = kwargs['axis']

        # An out-of-range axis would never be skipped while walking the
        # input, making the final index run past the end of output_shape
        axis_valid = 0 <= axis < len(input_shape)
        rank_valid = (len(input_shape) - 1) == len(output_shape)

        if axis_valid and rank_valid:
            expected_shape = list(input_shape[:axis]) + list(input_shape[axis + 1:])
            if any(exp != got for exp, got in zip(expected_shape, output_shape)):
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2196
@staticmethod
def evArgmaxOutputRankMismatch(check=False, **kwargs):
    """ERROR_IF validator for an ARGMAX output rank other than input rank - 1.

    With check=True, and only when ``kwargs['axis']`` lies in [0, rank),
    flags an ``output_shape`` whose length is not one less than that of
    ``input_shape``.
    """
    mismatch = False

    if check:
        output_shape = kwargs['output_shape']
        input_shape = kwargs['input_shape']
        axis = kwargs['axis']
        axis_in_range = 0 <= axis < len(input_shape)
        if axis_in_range and len(output_shape) != len(input_shape) - 1:
            mismatch = True

    return {
        "error_name": ErrorIf.ArgmaxOutputRankMismatch,
        "error_result": mismatch,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2220
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002221
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
    """ERROR_IF validator for a kernel dimension below one.

    With check=True, flags ``min(kwargs['kernel']) < 1``.
    """
    error_name = ErrorIf.KernelSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The check is `< 1`, so the reason says "one" (it previously said "zero")
    error_reason = "At least one kernel dimension is smaller than one"

    if check:
        kernel = kwargs['kernel']
        if min(kernel) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2241
@staticmethod
def evStrideSmallerOne(check=False, **kwargs):
    """ERROR_IF validator for a stride dimension below one.

    With check=True, flags ``min(kwargs['stride']) < 1``.
    """
    error_name = ErrorIf.StrideSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The check is `< 1`, so the reason says "one" (it previously said "zero")
    error_reason = "At least one stride dimension is smaller than one"

    if check:
        stride = kwargs['stride']
        if min(stride) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2261
@staticmethod
def evScaleTrue(check=False, **kwargs):
    """ERROR_IF validator for scale32 requested with an INT48 input.

    With check=True, flags a truthy ``kwargs['scale32']`` when
    ``kwargs['input_dtype']`` is INT48.
    """
    illegal = False

    if check:
        input_dtype = kwargs['input_dtype']
        scale32 = kwargs['scale32']
        if input_dtype == DType.INT48 and scale32:
            illegal = True

    return {
        "error_name": ErrorIf.ScaleTrue,
        "error_result": illegal,
        "error_reason": "Scale set to true but input type is INT48",
        "param_reqs": {"rank": None, "dtype": [DType.INT48], "shape": None},
    }
2282
@staticmethod
def evScaleNotTrue(check=False, **kwargs):
    """ERROR_IF validator for double_round requested without scale32.

    With check=True, flags a falsy ``kwargs['scale32']`` combined with a
    truthy ``kwargs['double_round']``.
    """
    illegal = False

    if check:
        scale32 = kwargs['scale32']
        double_round = kwargs['double_round']
        if double_round and not scale32:
            illegal = True

    return {
        "error_name": ErrorIf.ScaleNotTrue,
        "error_result": illegal,
        "error_reason": "Scale set to false but double round set to true",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2303
@staticmethod
def evTensorSizeInputOutputMismatch(check=False, **kwargs):
    """ERROR_IF validator for differing input/output element counts.

    With check=True, compares ``np.prod`` of ``kwargs['input_shape']`` and
    ``kwargs['output_shape']``.
    """
    mismatch = False

    if check:
        in_elements = np.prod(kwargs['input_shape'])
        out_elements = np.prod(kwargs['output_shape'])
        if in_elements != out_elements:
            mismatch = True

    return {
        "error_name": ErrorIf.TensorSizeInputOutputMismatch,
        "error_result": mismatch,
        "error_reason": "Input tensor size does not match output tensor size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2326
@staticmethod
def evStartSmallerZero(check=False, **kwargs):
    """ERROR_IF validator for a negative slice start.

    With check=True, and only when ``kwargs['start']`` has one entry per
    dimension of ``kwargs['input_shape']``, flags any negative start value.
    """
    negative = False

    if check:
        input_shape = kwargs['input_shape']
        start = kwargs['start']
        if len(start) == len(input_shape) and any(begin < 0 for begin in start):
            negative = True

    return {
        "error_name": ErrorIf.StartSmallerZero,
        "error_result": negative,
        "error_reason": "Starting point smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2350
2351
2352 @staticmethod
2353 def evSizeSmallerEqualZero(check=False, **kwargs):
2354 error_name = ErrorIf.SizeSmallerEqualZero
2355 param_reqs = {"rank": None, "dtype": None, "shape": None}
2356 error_result = False
2357 error_reason = "Size smaller than or equal to zero"
2358
2359 if check:
2360 input_shape = kwargs['input_shape']
2361 size = kwargs['size']
2362 rank = len(input_shape)
2363 if len(size) == rank:
2364 for index in range(rank):
2365 if size[index] <= 0:
2366 error_result = True
2367
2368 info_dict = {
2369 "error_name": error_name,
2370 "error_result": error_result,
2371 "error_reason": error_reason,
2372 "param_reqs": param_reqs
2373 }
2374 return info_dict
2375
2376
2377 @staticmethod
2378 def evStartSizeOutsideBounds(check=False, **kwargs):
2379 error_name = ErrorIf.StartSizeOutsideBounds
2380 param_reqs = {"rank": None, "dtype": None, "shape": None}
2381 error_result = False
2382 error_reason = "starting point plus size larger than input dimension"
2383
2384 if check:
2385 input_shape = kwargs['input_shape']
2386 start = kwargs['start']
2387 size = kwargs['size']
2388 rank = len(input_shape)
2389 if len(start) == rank and len(size) == rank:
2390 for index in range(rank):
2391 if start[index] + size[index] > input_shape[index]:
2392 error_result = True
2393
2394 info_dict = {
2395 "error_name": error_name,
2396 "error_result": error_result,
2397 "error_reason": error_reason,
2398 "param_reqs": param_reqs
2399 }
2400 return info_dict
2401
2402
2403 @staticmethod
2404 def evSizeOutputShapeMismatch(check=False, **kwargs):
2405 error_name = ErrorIf.SizeOutputShapeMismatch
2406 param_reqs = {"rank": None, "dtype": None, "shape": None}
2407 error_result = False
2408 error_reason = "Size does not match output dimension"
2409
2410 if check:
2411 input_shape = kwargs['input_shape']
2412 output_shape = kwargs['output_shape']
2413 size = kwargs['size']
2414 rank = len(input_shape)
2415 if len(size) == rank:
2416 for index in range(rank):
2417 if size[index] != output_shape[index]:
2418 error_result = True
2419
2420 info_dict = {
2421 "error_name": error_name,
2422 "error_result": error_result,
2423 "error_reason": error_reason,
2424 "param_reqs": param_reqs
2425 }
2426 return info_dict
2427
2428 @staticmethod
2429 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2430 error_name = ErrorIf.InputSizeStartLengthMismatch
2431 param_reqs = {"rank": None, "dtype": None, "shape": None}
2432 error_result = False
2433 error_reason = "rank of input not equal to length of start or size"
2434
2435 if check:
2436 input_shape = kwargs['input_shape']
2437 start = kwargs['start']
2438 size = kwargs['size']
2439 rank = len(input_shape)
2440 if rank != len(start) or rank != len(size):
2441 error_result = True
2442
2443 info_dict = {
2444 "error_name": error_name,
2445 "error_result": error_result,
2446 "error_reason": error_reason,
2447 "param_reqs": param_reqs
2448 }
2449 return info_dict
2450
2451 @staticmethod
2452 def evIndexOutsideBounds(check=False, **kwargs):
2453 error_name = ErrorIf.IndexOutsideBounds
2454 param_reqs = {"rank": None, "dtype": None, "shape": None}
2455 error_result = False
2456 error_reason = "Index outside of allowed bounds"
2457
2458 if check:
2459 input_shape = kwargs['input_shape']
2460 perms = kwargs['perms']
2461 rank = len(input_shape)
2462
2463 for index in perms:
2464 if index < 0 or index > rank:
2465 error_result = True
2466
2467 info_dict = {
2468 "error_name": error_name,
2469 "error_result": error_result,
2470 "error_reason": error_reason,
2471 "param_reqs": param_reqs
2472 }
2473 return info_dict
2474
2475 @staticmethod
2476 def evIndexUsedTwice(check=False, **kwargs):
2477 error_name = ErrorIf.IndexUsedTwice
2478 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2479 error_result = False
2480 error_reason = "Index used multiple times"
2481
2482 if check:
2483 input_shape = kwargs['input_shape']
2484 perms = kwargs['perms']
2485 rank = len(input_shape)
2486
2487 unique_indices = []
2488 for index in perms:
2489 if index in unique_indices:
2490 error_result = True
2491 else:
2492 unique_indices.append(index)
2493
2494 info_dict = {
2495 "error_name": error_name,
2496 "error_result": error_result,
2497 "error_reason": error_reason,
2498 "param_reqs": param_reqs
2499 }
2500 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002501
2502
class TosaInvalidValidator:
    """Predicates that reject argument combinations unusable for a test.

    Every ``iv*`` method is a static method taking keyword arguments only.  It
    returns True when the combination it inspects is invalid and False
    otherwise.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True for an unsupported RESIZE mode / data-type combination.

        Reads ``input_dtype`` and ``args`` from kwargs.  ``args[0]`` is the
        resize mode and ``args[8]`` is the output data type.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Valid only when (input, output) matches one of these pairs, so
            # every negated pair test must hold (AND).  An OR here would always
            # be True, because no combination can match two pairs at once.
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Output type must equal the input type, and the input must be one
            # of the RESIZE input types (INT8, INT16, FLOAT).
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if the stride is zero or negative on either axis.

        FLOAT inputs are checked against the floating-point stride ``args[4]``.
        Every other type is checked against the integer stride ``args[1]``.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x, stride_y = args[1][0], args[1][1]
        stride_fp_x, stride_fp_y = args[4][0], args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        elif stride_x <= 0 or stride_y <= 0:
            # Negative or zero stride
            return True
        return False

    @staticmethod
    def _convOutputDim(in_dim, kernel_dim, dilation, pad_before, pad_after, stride):
        """Output extent of a convolution with dilation along one spatial axis."""
        return (
            in_dim
            - kernel_dim
            - (kernel_dim - 1) * (dilation - 1)
            + pad_before
            + pad_after
        ) // stride + 1

    @staticmethod
    def _poolOutputDim(in_dim, kernel_dim, pad_before, pad_after, stride):
        """Output extent of a 2D pooling window along one spatial axis."""
        return (in_dim + pad_before + pad_after + stride - kernel_dim) // stride

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if a conv2d / depthwise_conv2d / pool2d output would be empty.

        Reads three kwargs:

        * ``opName``
        * ``shapeList``: the input shape first, then the filter shape for
          convolutions.
        * ``args``: ``args[0]`` is the stride, ``args[1]`` a 4-element pad,
          and ``args[2]`` the dilation for convolutions or the kernel for
          pooling.

        ``pad[0:2]`` apply to the height axis (input dim 1) and ``pad[2:4]``
        to the width axis (input dim 2).

        Raises:
            AssertionError: if ``opName`` is none of the supported ops.
        """
        opName = kwargs['opName']
        shape_list = kwargs['shapeList']
        args = kwargs['args']
        ifm_shape = shape_list[0]
        strides = args[0]
        padding = args[1]
        conv_dim = TosaInvalidValidator._convOutputDim
        pool_dim = TosaInvalidValidator._poolOutputDim

        if opName.startswith('conv2d'):
            # conv2d filters hold their spatial dims at indices 1 and 2.
            filter_shape = shape_list[1]
            dilations = args[2]
            h = conv_dim(ifm_shape[1], filter_shape[1], dilations[0],
                         padding[0], padding[1], strides[0])
            w = conv_dim(ifm_shape[2], filter_shape[2], dilations[1],
                         padding[2], padding[3], strides[1])
        elif opName.startswith("depthwise_conv2d"):
            # depthwise filters hold their spatial dims at indices 0 and 1.
            filter_shape = shape_list[1]
            dilations = args[2]
            h = conv_dim(ifm_shape[1], filter_shape[0], dilations[0],
                         padding[0], padding[1], strides[0])
            w = conv_dim(ifm_shape[2], filter_shape[1], dilations[1],
                         padding[2], padding[3], strides[1])
        elif opName.endswith("pool2d"):
            kernel = args[2]
            h = pool_dim(ifm_shape[1], kernel[0], padding[0], padding[1], strides[0])
            w = pool_dim(ifm_shape[2], kernel[1], padding[2], padding[3], strides[1])
        else:
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError("Unrecognized Op")

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if dims 1 or 2 of the requested output shape (``args[3]``) are <= 0."""
        output_shape = kwargs['args'][3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
2619
2620
Kevin Cheng550ccc52021-03-03 11:21:43 -08002621
Eric Kunzee5e26762020-10-13 16:11:07 -07002622class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002623 # Maximum rank of tensor supported by test generator.
2624 TOSA_TENSOR_MAX_RANK = 6
2625
Eric Kunzee5e26762020-10-13 16:11:07 -07002626 def __init__(self, args):
2627 self.args = args
2628 self.basePath = args.output_dir
2629 self.random_seed = args.random_seed
2630 self.ser = None
2631 self.rng = np.random.default_rng(self.random_seed)
2632 self.createDynamicOpLists()
2633 self.initOpListDefaults()
2634 self.quantGen = TosaQuantGen()
2635 # Force makeShape to do a specific starting shape
2636 self.targetted_shape = None
2637
2638 def createSerializer(self, opName, testPath):
2639 self.testPath = os.path.join(opName, testPath)
2640
2641 fullPath = os.path.join(self.basePath, self.testPath)
2642 os.makedirs(fullPath, exist_ok=True)
2643 self.ser = ts.TosaSerializer(fullPath)
2644
2645 def getSerializer(self):
2646 return self.ser
2647
2648 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002649 with open(
2650 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
2651 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07002652 fd.write(self.ser.serialize())
2653
Kevin Cheng550ccc52021-03-03 11:21:43 -08002654 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
2655 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002656
Matthew Haddon74567092021-07-16 15:38:20 +01002657 def resetRNG(self, seed=None):
2658 if seed == None:
2659 seed = self.random_seed + 1
2660 self.rng = np.random.default_rng(seed)
2661
Eric Kunzee5e26762020-10-13 16:11:07 -07002662 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07002663 if dtype == DType.BOOL:
2664 np_dt = np.bool
2665 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07002666 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002667 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002668 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002669 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002670 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
2671 elif dtype == DType.UINT8:
2672 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002673 elif dtype == DType.INT16:
2674 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
2675 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002676 return np.int32(
2677 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
2678 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002679 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002680 return np.int64(
2681 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
2682 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002683 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002684 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002685 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002686 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002687
Kevin Cheng989cb052021-04-28 16:29:44 -07002688 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002689 placeholders = []
2690
Kevin Cheng989cb052021-04-28 16:29:44 -07002691 assert len(shape_list) == len(dtype_list)
2692
2693 for idx, shape in enumerate(shape_list):
2694 arr = self.getRandTensor(shape, dtype_list[idx])
2695 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002696
2697 return placeholders
2698
Kevin Cheng989cb052021-04-28 16:29:44 -07002699 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002700 consts = []
2701
Kevin Cheng989cb052021-04-28 16:29:44 -07002702 assert len(shape_list) == len(dtype_list)
2703
2704 for idx, shape in enumerate(shape_list):
2705 arr = self.getRandTensor(shape, dtype_list[idx])
2706 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002707
2708 return consts
2709
2710 def makeShape(self, rank):
2711 if self.targetted_shape:
2712 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002713 return np.int32(
2714 self.rng.integers(
2715 low=self.args.tensor_shape_range[0],
2716 high=self.args.tensor_shape_range[1],
2717 size=rank,
2718 )
2719 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002720
2721 def setTargetShape(self, shape):
2722 self.targetted_shape = shape
2723
2724 def randInt(self, low=0, high=256):
2725 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
2726
2727 def getRandNumberDType(self, dtype):
2728 if dtype == DType.FLOAT:
2729 return self.rng.random()
2730 elif dtype == DType.BOOL:
2731 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07002732 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002733 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002734 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002735 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002736 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07002737 elif dtype == DType.INT16:
2738 low, high = (-32768, 32768)
2739 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002740 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07002741 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002742 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07002743 # Special size
2744 return np.int64(self.rng.integers(low, high, size=1))[0]
2745 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002746 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002747
2748 return np.int32(self.rng.integers(low, high, size=1))[0]
2749
2750 def shapeStr(self, shape):
2751
2752 sStr = []
2753 # Convert to strings
2754 for i in shape:
2755 sStr.append(str(i))
2756
Kevin Cheng550ccc52021-03-03 11:21:43 -08002757 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002758
2759 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07002760 if isinstance(t, list):
2761 assert len(t) >= 2
2762 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002763 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002764 if t == DType.BOOL:
2765 return "b"
2766 elif t == DType.INT4:
2767 return "i4"
2768 elif t == DType.INT8:
2769 return "i8"
2770 elif t == DType.UINT8:
2771 return "u8"
2772 elif t == DType.INT16:
2773 return "i16"
2774 elif t == DType.INT32:
2775 return "i32"
2776 elif t == DType.INT48:
2777 return "i48"
2778 elif t == DType.FLOAT:
2779 return "float"
2780 else:
2781 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002782
2783 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002784 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002785 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002786 return 4
2787 elif t == DType.INT8:
2788 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002789 elif t == DType.UINT8:
2790 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002791 elif t == DType.INT16:
2792 return 16
2793 elif t == DType.INT32:
2794 return 32
2795 elif t == DType.INT48:
2796 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01002797 elif t == DType.FLOAT:
2798 return 32
2799 elif t == DType.BOOL:
2800 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002801 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002802 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002803
2804 # Argument generators
2805 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2806 # Where the string descriptor is used to generate the test name and
2807 # The build_fcn_arg_list is expanded and passed to the operator test
2808 # build function
2809
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002810 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
2811 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
2812
Matthew Haddon848efb42021-09-09 12:30:53 +01002813 # build_placeholder returns an int, ABS/other ops does not
2814 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002815 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
2816 return result_tens
2817 elif op['op'] == Op.IDENTITY:
2818 self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
2819 return result_tens
2820
2821 # Ensure new output type has correct qinfo
2822 if error_name == ErrorIf.WrongOutputType:
2823 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
2824 qinfo = ts.TosaSerializerQuantInfo()
2825 qinfo.UnaryQuantInfo(
2826 TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2827 )
2828
2829 # Invalidate Input/Output list for error if checks.
2830 input_list = [a.name]
2831 output_list = [result_tens.name]
2832 pCount, cCount = op["operands"]
2833 num_operands = pCount + cCount
2834 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2835
2836 TosaErrorValidator.evValidateErrorIfs(
2837 self.ser,
2838 validator_fcns,
2839 error_name,
2840 op=op,
2841 input_dtype=a.dtype,
2842 output_dtype=result_tens.dtype,
2843 qinfo = qinfo,
2844 result_tensor = result_tens,
2845 input_list=input_list,
2846 output_list=output_list,
2847 num_operands=num_operands,
2848 )
2849
2850 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002851 return result_tens
2852
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002853 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
2854 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
2855
2856
2857 # Invalidate Input/Output list for error if checks.
2858 input_list = [a.name, b.name]
2859 output_list = [result_tens.name]
2860 pCount, cCount = op["operands"]
2861 num_operands = pCount + cCount
2862 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2863
2864 TosaErrorValidator.evValidateErrorIfs(
2865 self.ser,
2866 validator_fcns,
2867 error_name,
2868 op=op,
2869 input1 = a,
2870 input2 = b,
2871 input_dtype = a.dtype,
2872 output_dtype = result_tens.dtype,
2873 result_tensor = result_tens,
2874 input_list=input_list,
2875 output_list=output_list,
2876 num_operands=num_operands,
2877 )
2878
2879 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07002880 return result_tens
2881
2882 def build_binary_nonbroadcast(self, op, a, b):
2883 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002884 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002885 return result_tens
2886
Kevin Chengaee1fac2020-11-11 13:54:06 -08002887 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002888 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002889
2890 attr = ts.TosaSerializerAttribute()
2891 attr.ArithmeticRightShiftAttribute(round)
2892
Matthew Haddon848efb42021-09-09 12:30:53 +01002893 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002894 return result_tens
2895
2896 def build_mul(self, op, a, b, shift):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002897 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Eric Kunzee5e26762020-10-13 16:11:07 -07002898
2899 # Special for multiply:
2900 # Force the result to INT32 for INT types
2901 if a.dtype != DType.FLOAT:
2902 result_tens.setDtype(DType.INT32)
2903
Kevin Chengaee1fac2020-11-11 13:54:06 -08002904 attr = ts.TosaSerializerAttribute()
2905 attr.MulAttribute(shift)
2906
Matthew Haddon848efb42021-09-09 12:30:53 +01002907 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002908 return result_tens
2909
2910 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002911 # Constant size depending on type, random values
2912 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07002913 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002914 table_arr = self.getRandTensor([513], table_dtype)
2915 else:
2916 assert a.dtype == DType.INT8
2917 table_dtype = DType.INT8
2918 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002919
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002920 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
2921 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002922 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002923
2924 return result_tens
2925
2926 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07002927 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002928 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002929 return result_tens
2930
2931 def build_comparison(self, op, a, b):
2932 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002933 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002934 return result_tens
2935
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002936 def build_argmax(self, op, a, axis, validator_fcns, error_name):
2937 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
2938
2939 # Invalidate Input/Output list for error if checks.
2940 input_list = [a.name]
2941 output_list = [result_tens.name]
2942 pCount, cCount = op["operands"]
2943 num_operands = pCount + cCount
2944 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2945
2946 TosaErrorValidator.evValidateErrorIfs(
2947 self.ser,
2948 validator_fcns,
2949 error_name,
2950 op=op,
2951 axis=axis,
2952 input_shape = a.shape,
2953 input_dtype = a.dtype,
2954 output_shape = result_tens.shape,
2955 output_dtype = result_tens.dtype,
2956 result_tensor = result_tens,
2957 input_list=input_list,
2958 output_list=output_list,
2959 num_operands=num_operands,
2960 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002961
2962 attr = ts.TosaSerializerAttribute()
2963 attr.AxisAttribute(axis)
2964
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002965 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002966 return result_tens
2967
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002968 def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
2969 result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
2970
2971 # Ensure new output type has correct qinfo
2972 if error_name == ErrorIf.WrongInputType:
2973 if input.dtype not in [DType.INT8, DType.UINT8]:
2974 qinfo = ts.TosaSerializerQuantInfo()
2975 qinfo.UnaryQuantInfo(
2976 TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2977 )
2978
2979 # Invalidate Input/Output list for error if checks.
2980 input_list = [input.name]
2981 output_list = [result_tens.name]
2982 pCount, cCount = op["operands"]
2983 num_operands = pCount + cCount
2984 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2985
2986 TosaErrorValidator.evValidateErrorIfs(
2987 self.ser,
2988 validator_fcns,
2989 error_name,
2990 op=op,
2991 input_shape=input.shape,
2992 input_dtype=input.dtype,
2993 output_shape=result_tens.shape,
2994 output_dtype=result_tens.dtype,
2995 kernel=kernel,
2996 stride=stride,
2997 pad=pad,
2998 qinfo = qinfo,
2999 result_tensor = result_tens,
3000 input_list=input_list,
3001 output_list=output_list,
3002 num_operands=num_operands,
3003 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003004
3005 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003006 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07003007
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003008 self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003009 return result_tens
3010
3011 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003012 assert len(padding) == 4
3013 result_tens = OutputShaper.conv2dOp(
3014 self.ser, ifm, filter, strides, padding, dilations
3015 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003016
3017 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003018 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003019
Kevin Cheng550ccc52021-03-03 11:21:43 -08003020 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003021 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003022 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003023 return result_tens
3024
Kevin Cheng1533b852021-09-01 12:51:58 -07003025 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
3026 assert len(padding) == 6
3027 result_tens = OutputShaper.conv3dOp(
3028 self.ser, ifm, filter, strides, padding, dilations
3029 )
3030
3031 attr = ts.TosaSerializerAttribute()
3032 attr.ConvAttribute(padding, strides, dilations)
3033
3034 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003035 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07003036 )
3037 return result_tens
3038
Kevin Cheng550ccc52021-03-03 11:21:43 -08003039 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07003040 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003041 ):
3042 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07003043 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
3044
3045 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003046 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003047
Kevin Cheng550ccc52021-03-03 11:21:43 -08003048 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003049 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003050 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003051 return result_tens
3052
Kevin Cheng550ccc52021-03-03 11:21:43 -08003053 def build_depthwise_conv2d(
3054 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
3055 ):
3056 result_tens = OutputShaper.depthwiseConv2dOp(
3057 self.ser, ifm, filter, strides, padding, dilations
3058 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003059
3060 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003061 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003062
Kevin Cheng550ccc52021-03-03 11:21:43 -08003063 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003064 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003065 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003066 return result_tens
3067
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003068 def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
3069 result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)
3070
3071 # Invalidate Input/Output list for error if checks.
3072 input_list = [ifm.name, filter.name, bias.name]
3073 output_list = [result_tens.name]
3074 pCount, cCount = op["operands"]
3075 num_operands = pCount + cCount
3076 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3077
3078 TosaErrorValidator.evValidateErrorIfs(
3079 self.ser,
3080 validator_fcns,
3081 error_name,
3082 op=op,
3083 input_shape=ifm.shape,
3084 input_dtype=ifm.dtype,
3085 weight_dtype=filter.dtype,
3086 output_shape=result_tens.shape,
3087 output_dtype=result_tens.dtype,
3088 qinfo = qinfo,
3089 result_tensor = result_tens,
3090 input_list=input_list,
3091 output_list=output_list,
3092 num_operands=num_operands,
3093 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003094
Kevin Cheng550ccc52021-03-03 11:21:43 -08003095 self.ser.addOperator(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003096 op['op'], input_list, output_list, None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003097 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003098 return result_tens
3099
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003100 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
3101 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
3102
3103 # Invalidate Input/Output list for error if checks.
3104 input_list = [a.name, b.name]
3105 output_list = [result_tens.name]
3106 pCount, cCount = op["operands"]
3107 num_operands = pCount + cCount
3108 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3109
3110 TosaErrorValidator.evValidateErrorIfs(
3111 self.ser,
3112 validator_fcns,
3113 error_name,
3114 op=op,
3115 input_shape=a.shape,
3116 input_dtype=a.dtype,
3117 input2_shape=b.shape,
3118 input2_dtype=b.dtype,
3119 output_shape=result_tens.shape,
3120 output_dtype=result_tens.dtype,
3121 qinfo = qinfo,
3122 result_tensor = result_tens,
3123 input_list=input_list,
3124 output_list=output_list,
3125 num_operands=num_operands,
3126 )
3127
3128 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003129 return result_tens
3130
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator over ``axis`` of tensor ``a``.

    Args:
        op: Operator entry; ``op["op"]`` is the opcode and ``op["operands"]``
            a pair of operand counts.
        a: Input tensor.
        axis: Axis stored in the AxisAttribute.
        validator_fcns: ERROR_IF validators to run. Defaults to None, matching
            the other ERROR_IF-aware builders so callers may omit it.
        error_name: ERROR_IF condition being generated, or None.

    Returns:
        The result tensor produced by ``OutputShaper.reduceOp``.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3162
def build_clamp(self, op, a):
    """Serialize a CLAMP operator on ``a`` with two random bounds.

    Two random values of ``a.dtype`` are drawn; the smaller becomes the lower
    bound and the larger the upper bound. For FLOAT inputs the bounds go in the
    last two ClampAttribute arguments, otherwise in the first two (the unused
    pair is zero).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    draws = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
    lo, hi = min(draws), max(draws)

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype != DType.FLOAT:
        clamp_attr.ClampAttribute(lo, hi, 0, 0)
    else:
        clamp_attr.ClampAttribute(0, 0, lo, hi)

    self.ser.addOperator(op["op"], [a.name], [result_tens.name], clamp_attr)
    return result_tens
3176
def build_leaky_relu(self, op, a):
    """Serialize a LEAKY_RELU operator on ``a``.

    The LeakyReluAttribute holds a single random FLOAT value.
    """
    relu_out = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
    self.ser.addOperator(op["op"], [a.name], [relu_out.name], relu_attr)
    return relu_out
3185
# NOTE: PRELU needs an additional type/input; only ``a`` is wired up here.
def build_prelu(self, op, a):
    """Serialize a PRELU operator with ``a`` as its only input."""
    prelu_out = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    self.ser.addOperator(op["op"], [a.name], [prelu_out.name])
    return prelu_out
3192
def build_sigmoid(self, op, a):
    """Serialize a SIGMOID operator on ``a`` and return its output tensor."""
    sigmoid_out = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    self.ser.addOperator(op["op"], [a.name], [sigmoid_out.name])
    return sigmoid_out
3197
def build_tanh(self, op, a):
    """Serialize a TANH operator on ``a`` and return its output tensor."""
    tanh_out = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    self.ser.addOperator(op["op"], [a.name], [tanh_out.name])
    return tanh_out
3202
def build_concat(self, op, *a):
    """Serialize a CONCAT operator.

    ``a`` holds the input tensors followed by the integer concatenation axis
    as its final element. The axis travels with the variable-length tensor
    list so both fit in one generic argument list.
    """
    assert type(a[-1]) == int

    axis, tensors = a[-1], a[:-1]

    result_tens = OutputShaper.concatOp(self.ser, axis, *tensors)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    input_names = [tensor.name for tensor in tensors]
    self.ser.addOperator(op["op"], input_names, [result_tens.name], axis_attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003221
def build_pad(self, op, a, padding, validator_fcns=None, error_name=None, qinfo=None):
    """Serialize a PAD operator applying ``padding`` to tensor ``a``.

    ``padding`` is emitted as an INT32 constant tensor built from
    ``padding.shape``; it is one of the few operands that is not randomly
    generated. The input/output name lists may be invalidated for error-if
    tests, and the ERROR_IF validators are run before the operator is added.

    Returns:
        The result tensor produced by ``OutputShaper.padOp``.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The padding array becomes a constant TOSA tensor operand.
    pad_const = self.ser.addConst(padding.shape, DType.INT32, padding)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, pad_const.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )

    self.ser.addOperator(op["op"], inputs, outputs, None, qinfo)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003258
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of tensor ``a`` to ``newShape``.

    The input/output name lists may be invalidated for error-if tests, the
    ERROR_IF validators are run, and ``newShape`` is recorded in a
    ReshapeAttribute.

    Returns:
        The result tensor produced by ``OutputShaper.reshapeOp``.
    """
    result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return result_tens
3289
def build_reverse(self, op, a, axis):
    """Serialize a REVERSE operator on ``a`` with ``axis`` as its AxisAttribute."""
    reversed_out = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], [a.name], [reversed_out.name], axis_attr)
    return reversed_out
3298
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of tensor ``a`` by permutation ``perms``.

    ``perms`` is emitted as a 1-D INT32 constant tensor of length
    ``len(perms)``. The input/output name lists may be invalidated for
    error-if tests, and the ERROR_IF validators are run before the operator is
    added.

    Returns:
        The result tensor produced by ``OutputShaper.transposeOp``.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perms_const = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, perms_const.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )

    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
3330
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of tensor ``a`` described by ``start`` and ``size``.

    The input/output name lists may be invalidated for error-if tests, the
    ERROR_IF validators are run, and ``start``/``size`` are recorded in a
    SliceAttribute.

    Returns:
        The result tensor produced by ``OutputShaper.sliceOp``.
    """
    result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        start=start,
        size=size,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return result_tens
3363
def build_tile(self, op, a, multiples):
    """Serialize a TILE operator on ``a`` with ``multiples`` in a TileAttribute."""
    tiled_out = OutputShaper.tileOp(self.ser, a, multiples)

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], [a.name], [tiled_out.name], tile_attr)
    return tiled_out
3372
def build_gather(self, op, values):
    """Serialize a GATHER operator that reads from ``values``.

    A fresh INT32 constant index tensor of shape ``[values.shape[0], W]`` is
    created. W is drawn with ``self.randInt`` over
    ``self.args.tensor_shape_range``, and every index is drawn from
    ``[0, values.shape[1])``, so indices never exceed the values tensor.

    Returns:
        The result tensor produced by ``OutputShaper.gatherOp``.
    """
    k_extent = values.shape[1]
    lo, hi = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    w_extent = self.randInt(lo, hi)

    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values.shape[0], w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.gatherOp(self.ser, values, indices)

    self.ser.addOperator(op["op"], [values.name, indices.name], [result_tens.name])

    return result_tens
3392
def build_scatter(self, op, values_in, input):
    """Serialize a SCATTER of ``input`` into ``values_in``.

    A fresh INT32 constant index tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is created. Every index is drawn
    from ``[0, values_in.shape[1])``, so indices never exceed the values_in
    tensor.

    Returns:
        The result tensor produced by ``OutputShaper.scatterOp``.
    """
    k_extent = values_in.shape[1]
    w_extent = input.shape[1]

    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values_in.shape[0], w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.scatterOp(self.ser, values_in, indices, input)

    operand_names = [values_in.name, indices.name, input.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name])

    return result_tens
3412
Matthew Haddon848efb42021-09-09 12:30:53 +01003413
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns=None,
    error_name=None,
):
    """Serialize a RESIZE operator on ``input``.

    The output tensor comes from ``OutputShaper.resizeOp``. The input/output
    name lists may be invalidated for error-if tests, and the ERROR_IF
    validators are run. All resize parameters are then stored in a
    ResizeAttribute.

    ``validator_fcns`` defaults to None, matching the other ERROR_IF-aware
    builders, so callers may omit it.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3482
def build_identityn(self, op, val, val2):
    """Serialize an identity-style operator with two inputs and two outputs.

    ``val`` and ``val2`` each get an output tensor shaped by
    ``OutputShaper.unaryOp``. Only the first result tensor is returned.

    ``op`` is the operator entry dict, so the opcode passed to
    ``addOperator`` is ``op["op"]``, as in every other builder. Passing the
    dict itself was a leftover from when ``op`` was a bare opcode.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
3490
def build_const(self, op, val):
    """Register the existing tensor ``val`` as an output and return it.

    ``op`` is not used; the tensor itself is the whole "operator".
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07003494
# Type Conversion
def build_cast(self, op, val, out_dtype):
    """Serialize a CAST operator converting ``val`` to ``out_dtype``."""
    cast_out = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
    self.ser.addOperator(op["op"], [val.name], [cast_out.name])
    return cast_out
3500
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns=None,
    error_name=None,
):
    """Serialize a RESCALE operator converting ``val`` to ``out_dtype``.

    Random input/output zero points are chosen from the INT8/UINT8 ranges.
    For the InputZeroPointNotZero/OutputZeroPointNotZero ERROR_IF cases, a
    forced non-zero zero point is used for other dtypes. Per-channel (last
    dimension) or single multiplier/shift values are derived from a random
    scale.

    Args:
        op: Operator entry; ``op["op"]`` is the opcode and ``op["operands"]``
            a pair of operand counts.
        val: Input tensor.
        out_dtype: Output dtype.
        scale32: Whether 32-bit (True) or 16-bit scaling is used.
        double_round: Stored in the RescaleAttribute.
        per_channel: If true, one multiplier/shift per ``val.shape[-1]``.
        validator_fcns: ERROR_IF validators to run (default None).
        error_name: ERROR_IF condition being generated (default None).

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

    # Number of multiplier/shift pairs to generate.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name == ErrorIf.InputZeroPointNotZero:
        # Nudge a zero draw away from 0 so the zero point is never zero here.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name == ErrorIf.OutputZeroPointNotZero:
        # Nudge a zero draw away from 0 so the zero point is never zero here.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3604
def build_cond_if_const(self, op, then_tens, else_tens, cond):
    """Serialize a COND_IF whose THEN/ELSE blocks each output one constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. ``cond``
    becomes a BOOL constant condition. Each block holds a single INT32 constant
    of random values in [0, 256) that it returns as its output.

    Returns:
        The INT32 result tensor of the COND_IF.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    shape = then_tens.shape
    then_data = np.int32(self.rng.integers(0, 256, size=shape))
    else_data = np.int32(self.rng.integers(0, 256, size=shape))

    result_tens = self.ser.addOutput(shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], cond_attr)

    # Build each block with its const node as the block output.
    for block_name, block_data in ((then_block, then_data), (else_block, else_data)):
        self.ser.startBasicBlock(block_name)
        block_const = self.ser.addConst(shape, DType.INT32, block_data)
        self.ser.addOutputTensor(block_const)

    return result_tens
3640
def build_cond_if_binary(self, op, a, b, cond):
    """Serialize a COND_IF whose blocks apply a binary op to ``a`` and ``b``.

    FLOAT/INT32 inputs use ADD in the THEN block and SUB in the ELSE block.
    INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT. Other
    dtypes trip an assertion.

    Returns:
        The result tensor of the COND_IF, shaped and typed like ``a``.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], cond_attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        branch_ops = (Op.ADD, Op.SUB)
    elif a.dtype in (DType.INT8, DType.INT16):
        branch_ops = (Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT)
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Each block takes a and b as inputs and outputs one tensor like a.
    for block_name, block_op in zip((then_block, else_block), branch_ops):
        self.ser.startBasicBlock(block_name)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        block_out = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [block_out.name])

    return result_tens
3676
def build_while_loop(self, op, a, iter_val):
    """Serialize a WHILE_LOOP that adds ``a`` into an accumulator.

    The loop counter starts at ``iter_val`` and the accumulator at zeros. The
    COND block continues while the counter is greater than 0. The BODY block
    computes ``acc = a + acc`` and ``counter = counter - 1``, passing ``a``
    through unchanged.

    Returns:
        The accumulator output tensor of the loop.
    """
    counter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    loop_attr = ts.TosaSerializerAttribute()
    loop_attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator placeholder, initialised from np.int32 zeros with a's dtype.
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    counter_out = self.ser.addIntermediate(counter.shape, counter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [counter.name, a.name, acc.name],
        [counter_out.name, a_out.name, acc_out.name],
        loop_attr,
    )
    self.ser.addOutputTensor(acc_out)

    # COND block: inputs (counter, a, acc); output is counter > 0.
    self.ser.startBasicBlock(cond_block)
    for loop_input in (counter, a, acc):
        self.ser.addInputTensor(loop_input)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
    cond_tens = self.ser.addOutput([], DType.BOOL)
    self.ser.addOperator(Op.GREATER, [counter.name, zero_tens.name], [cond_tens.name])

    # BODY block: same inputs; local intermediates must be declared for outputs.
    self.ser.startBasicBlock(body_block)
    for loop_input in (counter, a, acc):
        self.ser.addInputTensor(loop_input)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
    counter_next = self.ser.addIntermediate(counter.shape, counter.dtype)
    acc_next = self.ser.addIntermediate(acc.shape, acc.dtype)
    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_next.name])
    self.ser.addOperator(Op.SUB, [counter.name, one_tens.name], [counter_next.name])
    for loop_output in (counter_next, a, acc_next):
        self.ser.addOutputTensor(loop_output)

    return acc_out
3730
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
    """Build the shape/rank/dtype filters used to enumerate tests for ``op``.

    Args:
        op: Operator entry; ``op["rank"]`` is an inclusive (min, max) pair and
            ``op["types"]`` the supported dtypes. An entry may itself be a
            list, in which case its first element is matched against
            ``dtypeFilter``.
        shapeFilter: Requested shapes; a falsy value means "no filter".
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None.
        testType: ``'positive'`` or ``'negative'``.
        validator: ERROR_IF validator for negative tests. It is called as
            ``validator(check=False, op=op)`` and its ``'param_reqs'`` entry
            supplies ``'rank'``/``'dtype'``/``'shape'`` overrides, where None
            means "no override".

    Returns:
        A dict with ``'shapeFilter'``, ``'rankFilter'`` and ``'dtypeFilter'``.
        Returns None for a negative test without a validator, or for an
        unrecognised ``testType``.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only requested ranks that the operator allows.
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    error_arguments = validator(check=False, op=op)["param_reqs"]

    # Set parameters as required. Use `is not None` rather than `!= None`:
    # an array-valued requirement would otherwise be compared element-wise
    # and make the `if` raise.
    rank_req = error_arguments["rank"]
    dtype_req = error_arguments["dtype"]
    shape_req = error_arguments["shape"]

    if shape_req is None:
        # Reduce number of shapes to keep test numbers small
        shape_req = shapeFilter[:2]

    return {
        "shapeFilter": shape_req,
        "rankFilter": rank_req if rank_req is not None else cleanRankFilter,
        "dtypeFilter": dtype_req if dtype_req is not None else cleanDtypeFilter,
    }
3801
3802
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of test cases for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: list of requested shapes (``None`` entries mean "no
            specific shape"); ``None`` is treated as ``[None]``.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: ``'positive'`` for valid tests or ``'negative'`` for
            ERROR_IF tests driven by the op's ``error_if_validators``.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. An empty list is returned as soon as
        ``self.create_filter_lists`` returns ``None``.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Avoid a shared mutable default argument
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            if opName.startswith("conv3d"):
                assert r == 5, "conv3d test must have input rank == 5"
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement: a test is dropped if ANY validator flags it
        # (previously only the last validator's verdict was honoured).
        invalid_test_validators = op.get("invalid_test_validators")
        if invalid_test_validators:
            testList = [
                test
                for test in testList
                if not any(
                    validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    )
                    for validator_fcn in invalid_test_validators
                )
            ]

    return testList
3893
Matthew Haddone86fd342021-09-07 16:12:21 +01003894
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Generate the operands for one test, build the op graph and serialize it.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name handed to ``self.createSerializer``.
        dtype_or_dtypeList: a single dtype applied to every operand, or a
            list holding one dtype per operand.
        error_name: the ERROR_IF under test, or ``None`` for a positive test.
        shapeList: one shape per operand (for CONCAT, one per input).
        testArgs: extra positional arguments forwarded to the build function.

    Raises:
        Exception: when ``opName`` is unknown.
        TypeError: re-raised from the build function after debug output.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tensor_gen, _arg_gen = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # Replicate one dtype across every operand (every input for CONCAT)
        copies = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_concat:
        assert len(shapeList) == num_operands, (
            "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
        )
        assert len(dtypeList) == num_operands, (
            "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )
        )

    qgen = op.get("qgen")

    # Build the random tensor operands and the test
    tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Arguments appended after the operands and testArgs depend on whether
    # the op declares ERROR_IF validators and whether it has quantization info
    trailing_args = []
    if error_if_validators is not None:
        trailing_args.extend((error_if_validators, error_name))
    if qinfo is not None:
        trailing_args.append(qinfo)

    try:
        resultName = build_fcn(self, op, *tens, *testArgs, *trailing_args)
    except TypeError:
        print(
            "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                build_fcn, tens, testArgs
            )
        )
        raise

    if resultName is None:
        print("Invalid ERROR_IF tests created")

    # Save the serialized test
    self.serialize("test")
3971
3972
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
    """Create the input tensors (placeholders then consts) for one test.

    Most ops get random tensors from ``buildPlaceholderTensors`` /
    ``buildConstTensors``. Some ops get constrained data instead:

    * ADD/SUB with INT32 and no ERROR_IF: operand 1 is adjusted so that the
      result stays inside the int32 range.
    * COND_IF/WHILE_LOOP with INT32: values are drawn with the INT16 generator
      to stop the add/sub ops inside them from saturating.
    * ARITHMETIC_RIGHT_SHIFT: operand 1 (the shift) is drawn from
      ``[0, num_bits)`` of its dtype.
    * SELECT: the condition operand is forced to BOOL.
    * INTDIV with no ERROR_IF: regenerated until there is no zero divisor and
      not both a ``-(2 ** 31)`` dividend and a ``-1`` divisor.
    * MUL (integer): inputs are halved until the (rounded, shifted) product
      fits in int32.
    * CONCAT: the last ``self.args.num_const_inputs_concat`` inputs become
      consts (at least one placeholder is kept).

    Args:
        op: the ``TOSA_OP_LIST`` entry of the operator.
        dtypeList: one dtype per operand. NOTE: for SELECT, element 0 is
            overwritten with ``DType.BOOL`` in place.
        shapeList: one shape per operand.
        testArgs: build-function arguments; ``testArgs[0]`` is used as the
            shift for integer MUL and is passed to
            ``TosaTensorGen.tgConcatConstInput`` for CONCAT.
        error_name: the ERROR_IF under test, or ``None`` for a positive test.

    Returns:
        List of serializer tensors.

    Raises:
        Exception: for an unsupported dtype of ARITHMETIC_RIGHT_SHIFT or MUL.
    """
    pCount, cCount = op["operands"]
    opcode = op["op"]

    tens = []
    if (
        (opcode == Op.ADD or opcode == Op.SUB)
        and dtypeList[0] == DType.INT32
        and error_name is None
    ):
        # Make sure the operation does not cause value saturation - where
        # the number wraps due to limited number of bits to store the answer
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
        placeholders = []
        add = opcode == Op.ADD
        a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
        b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
        if add:
            res_arr = np.add(a_arr, b_arr, dtype=np.int64)
        else:
            res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

        # Work out the saturation limits
        max_i32 = (1 << 31) - 1
        min_i32 = -(1 << 31)

        # Find how much values exceed the maximum/minimums; the scalar limits
        # broadcast against res_arr, so no full-size limit arrays are needed
        sat_max_arr = np.maximum(res_arr - max_i32, 0)
        sat_min_arr = np.minimum(res_arr - min_i32, 0)

        if not add:
            # Swap saturation values and negate values as we need to perform
            # opposite operations
            sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

        # Create new array of unsaturated values by clipping values as needed
        b_unsat_arr = b_arr
        if (sat_max_arr != 0).any():
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert (
                        dim == 1
                    ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

        if (sat_min_arr != 0).any():
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert (
                        dim == 1
                    ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

        placeholders.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
        )
        placeholders.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
        )

        tens.extend(placeholders)
    elif (opcode == Op.COND_IF or opcode == Op.WHILE_LOOP) and dtypeList[
        0
    ] == DType.INT32:
        # Limit input tensors with cond_if_binary or while_loop to stop
        # saturation of add/sub ops. The first pCount operands become
        # placeholders, the rest consts.
        placeholders = []
        for idx, shape in enumerate(shapeList):
            arr = self.getRandTensor(shape, DType.INT16)
            if idx < pCount:
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
            else:
                placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))

        tens.extend(placeholders)
    elif opcode == Op.ARITHMETIC_RIGHT_SHIFT:
        # Force value of operand[1] to be within [0, num_bits): the upper
        # bound of rng.integers is exclusive
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    num_bits = 8
                elif dtypeList[idx] == DType.INT16:
                    num_bits = 16
                elif dtypeList[idx] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
                arr = np.int32(self.rng.integers(low=0, high=num_bits, size=shape))
            else:
                arr = self.getRandTensor(shape, dtypeList[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

        tens.extend(placeholders)
    elif opcode == Op.SELECT:
        # Set datatype of condition tensor to boolean (mutates dtypeList)
        dtypeList[0] = DType.BOOL
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
    elif opcode == Op.INTDIV and error_name is None:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.INTDIV must have 2 placeholders, 0 consts"

        placeholders = []

        # Two invalid cases for Op.INTDIV:
        # 1. divisor == 0
        # 2. dividend == -(1<<31) and divisor == -1
        # Both are checked conservatively (anywhere in the tensors), and the
        # operands are regenerated until neither can occur.
        while True:
            dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

            if (divisor_arr == 0).any():
                continue

            if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                continue

            break

        placeholders.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
        )
        placeholders.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
        )

        tens.extend(placeholders)
    elif opcode == Op.MUL:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.MUL must have 2 placeholders, 0 consts"

        if dtypeList[0] == DType.FLOAT:
            tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
        else:
            placeholders = []

            # Make sure multiply result in int32 range
            shift = testArgs[0]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            else:
                raise Exception("OpMul: invalid input dtype")

            low = -(2 ** (num_bits - 1))
            high = (2 ** (num_bits - 1)) - 1

            a_arr = np.int32(
                self.rng.integers(low=low, high=high, size=shapeList[0])
            )
            b_arr = np.int32(
                self.rng.integers(low=low, high=high, size=shapeList[1])
            )

            # Halve both inputs until the (rounded, shifted) product fits
            while True:
                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2 ** 31)).all() and (
                    result_arr <= ((2 ** 31) - 1)
                ).all():
                    break

                a_arr = a_arr // 2
                b_arr = b_arr // 2

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )

            tens.extend(placeholders)
    elif opcode == Op.CONCAT:
        # Leading inputs become placeholders, the trailing
        # num_const_inputs_concat become consts (keeping >= 1 placeholder,
        # unless no consts were requested at all)
        num_const = self.args.num_const_inputs_concat
        if num_const == 0:
            count = len(shapeList)
        else:
            count = max(1, len(shapeList) - num_const)

        shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
    else:
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

    return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004187
def createDynamicOpLists(self):
    """Expand the templated convolution entries of ``self.TOSA_OP_LIST``.

    For every 2D kernel size, entries named ``conv2d_<h>x<w>``,
    ``depthwise_conv2d_<h>x<w>`` and ``transpose_conv2d_<h>x<w>`` are created
    as shallow copies of the matching ``*_TEMPLATE`` entry, with ``"filter"``
    set to the kernel and ``"template"`` set to False. ``conv3d_<d>x<h>x<w>``
    entries are produced the same way from ``conv3d_TEMPLATE``. Afterwards
    every entry whose ``"template"`` field equals True is deleted.
    """
    op_list = self.TOSA_OP_LIST

    def add_variant(name, template_name, kernel):
        # Clone the template and mark the clone as a concrete (non-template) op
        variant = op_list[template_name].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list[name] = variant

    conv2d_variants = (
        ("conv2d_{}x{}", "conv2d_TEMPLATE"),
        ("depthwise_conv2d_{}x{}", "depthwise_conv2d_TEMPLATE"),
        ("transpose_conv2d_{}x{}", "transpose_conv2d_TEMPLATE"),
    )
    for kernel in [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]:
        for name_fmt, template_name in conv2d_variants:
            add_variant(name_fmt.format(*kernel), template_name, kernel)

    for kernel in [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]:
        add_variant("conv3d_{}x{}x{}".format(*kernel), "conv3d_TEMPLATE", kernel)

    # Collect first, delete second: a dict must not change size while it is
    # being iterated
    template_names = [
        name
        for name, entry in op_list.items()
        if entry.get("template") == True  # noqa: E712
    ]
    for name in template_names:
        del op_list[name]
4234
def initOpListDefaults(self):
    """Validate every ``TOSA_OP_LIST`` entry and fill in optional defaults.

    Required fields are checked for presence (and, for ``operands`` and
    ``build_fcn``, for the expected tuple arity); a missing or malformed
    required field raises ``Exception`` naming the op. A missing ``rank`` is
    set to ``self.DEFAULT_RANK_RANGE``.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required: "operands" must unpack as (placeholders, consts)
        try:
            _num_placeholders, _num_consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        # Required: "build_fcn" must unpack as (build, tensor gen, arg gen)
        try:
            _build, _tensor_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        # Required: "types" and "op" only need to be present
        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )
        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(name)
            )

        # Optional: fall back to the default rank range
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07004276
# Tensor operator list (TOSA_OP_LIST), keyed by test/operator name.
# Each entry is a dict with the fields:
#   'op': the TOSA opcode (an Op.* value)
#   'operands': tuple of (placeholder, const) operand counts
#   'rank': optional, restricts rank to tuple inclusive of (min, max);
#           if not specified, initOpListDefaults() fills in
#           DEFAULT_RANK_RANGE, i.e. (1, TOSA_TENSOR_MAX_RANK)
#   'build_fcn': tuple of (build function, TosaTensorGen function,
#                TosaArgGen function or None)
#   'types': array of datatypes to be tested
#   'qgen': optional TosaQuantGen function producing the qinfo passed to
#           the build function
#   'error_if_validators': optional tuple of TosaErrorValidator functions
#                          used to generate negative (ERROR_IF) tests
#   'invalid_test_validators': optional tuple of TosaInvalidValidator
#                              functions used to drop invalid positive tests
#   'template': True for entries expanded by createDynamicOpLists()
#   'filter': kernel size, set on each entry created by createDynamicOpLists()
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Convolution type combinations. A list entry supplies one dtype per operand
# (presumably [input, weight, accumulator] -- TODO confirm against the build
# functions); a single dtype is replicated across all operands.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# (min, max) rank range used for ops whose entry has no 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07004304
4305 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08004306 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004307 "argmax": {
4308 "op": Op.ARGMAX,
4309 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004310 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004311 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4312 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004313 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
4314 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4315 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004317 "avg_pool2d": {
4318 "op": Op.AVG_POOL2D,
4319 "operands": (1, 0),
4320 "rank": (4, 4),
4321 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4322 "qgen": TosaQuantGen.qgUnary,
4323 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004324 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4325 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4326 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4327 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4328 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004329 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004330 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004331 "conv2d_TEMPLATE": {
4332 "op": Op.CONV2D,
4333 "operands": (1, 2),
4334 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01004335 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004336 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004337 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004338 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004339 "template": True,
4340 },
Kevin Cheng1533b852021-09-01 12:51:58 -07004341 # Templated operator. Filled in by createDynamicOpLists
4342 "conv3d_TEMPLATE": {
4343 "op": Op.CONV3D,
4344 "operands": (1, 2),
4345 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01004346 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07004347 "qgen": TosaQuantGen.qgConv,
4348 "types": TYPE_CONV,
4349 "template": True,
4350 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004351 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004352 "depthwise_conv2d_TEMPLATE": {
4353 "op": Op.DEPTHWISE_CONV2D,
4354 "operands": (1, 2),
4355 "filter": [1, 1],
4356 "rank": (4, 4),
4357 "build_fcn": (
4358 build_depthwise_conv2d,
4359 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01004360 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004361 ),
4362 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004363 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004364 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004365 "template": True,
4366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 "fully_connected": {
4368 "op": Op.FULLY_CONNECTED,
4369 "operands": (1, 2),
4370 "rank": (2, 2),
4371 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
4372 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004373 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004374 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
4375 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004377 "matmul": {
4378 "op": Op.MATMUL,
4379 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07004380 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08004381 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
4382 "qgen": TosaQuantGen.qgMatmul,
4383 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004384 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4385 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004386 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004387 "max_pool2d": {
4388 "op": Op.MAX_POOL2D,
4389 "operands": (1, 0),
4390 "rank": (4, 4),
4391 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4392 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004393 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4394 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4395 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4396 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004397 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004398 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004399 "transpose_conv2d_TEMPLATE": {
4400 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07004401 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004402 "rank": (4, 4),
4403 "build_fcn": (
4404 build_transpose_conv2d,
4405 TosaTensorGen.tgTransposeConv2D,
4406 TosaArgGen.agTransposeConv2D,
4407 ),
4408 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004409 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004410 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004411 "template": True,
4412 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004413 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08004414 "clamp": {
4415 "op": Op.CLAMP,
4416 "operands": (1, 0),
4417 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
4418 "types": TYPE_NARROW_INT_FP,
4419 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004420 "sigmoid": {
4421 "op": Op.SIGMOID,
4422 "operands": (1, 0),
4423 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
4424 "types": TYPE_FP,
4425 },
4426 "tanh": {
4427 "op": Op.TANH,
4428 "operands": (1, 0),
4429 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
4430 "types": TYPE_FP,
4431 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004432 # Elementwise Binary Operators
4433 "add": {
4434 "op": Op.ADD,
4435 "operands": (2, 0),
4436 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4437 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004438 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4439 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004440 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004441 "arithmetic_right_shift": {
4442 "op": Op.ARITHMETIC_RIGHT_SHIFT,
4443 "operands": (2, 0),
4444 "build_fcn": (
4445 build_arithmetic_right_shift,
4446 TosaTensorGen.tgBroadcastFuzz,
4447 TosaArgGen.agArithmeticRightShift,
4448 ),
4449 "types": TYPE_INT,
4450 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004451 "bitwise_and": {
4452 "op": Op.BITWISE_AND,
4453 "operands": (2, 0),
4454 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4455 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004456 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4457 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004458 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004459 "bitwise_or": {
4460 "op": Op.BITWISE_OR,
4461 "operands": (2, 0),
4462 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4463 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004464 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4465 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004466 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004467 "bitwise_xor": {
4468 "op": Op.BITWISE_XOR,
4469 "operands": (2, 0),
4470 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4471 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004472 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4473 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004474 },
Matthew Haddon459443c2021-08-23 16:43:13 +01004475 "intdiv": {
4476 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004477 "operands": (2, 0),
4478 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4479 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004480 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4481 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004482 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004483 "logical_and": {
4484 "op": Op.LOGICAL_AND,
4485 "operands": (2, 0),
4486 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4487 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004488 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4489 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004491 "logical_left_shift": {
4492 "op": Op.LOGICAL_LEFT_SHIFT,
4493 "operands": (2, 0),
4494 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4495 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004496 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4497 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004498 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004499 "logical_right_shift": {
4500 "op": Op.LOGICAL_RIGHT_SHIFT,
4501 "operands": (2, 0),
4502 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4503 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004504 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4505 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004506 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004507 "logical_or": {
4508 "op": Op.LOGICAL_OR,
4509 "operands": (2, 0),
4510 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4511 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004512 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4513 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004515 "logical_xor": {
4516 "op": Op.LOGICAL_XOR,
4517 "operands": (2, 0),
4518 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4519 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004520 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4521 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004522 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004523 "maximum": {
4524 "op": Op.MAXIMUM,
4525 "operands": (2, 0),
4526 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4527 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004528 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4529 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004530 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004531 "minimum": {
4532 "op": Op.MINIMUM,
4533 "operands": (2, 0),
4534 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4535 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004536 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4537 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004538 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004539 "mul": {
4540 "op": Op.MUL,
4541 "operands": (2, 0),
4542 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
4543 "types": TYPE_INT_FP,
4544 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004545 "pow": {
4546 "op": Op.POW,
4547 "operands": (2, 0),
4548 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
4549 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004550 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4551 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004552 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004553 "sub": {
4554 "op": Op.SUB,
4555 "operands": (2, 0),
4556 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4557 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004558 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4559 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004560 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004561 "table": {
4562 "op": Op.TABLE,
4563 # Use the automatic generation functions to create the input array
4564 # but create the table tensor in the build function, as it may be
4565 # a different type from the input
4566 "operands": (1, 0),
4567 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004568 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08004569 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004570 # Elementwise Unary operators
4571 "abs": {
4572 "op": Op.ABS,
4573 "operands": (1, 0),
4574 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4575 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004576 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4577 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004578 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004579 "bitwise_not": {
4580 "op": Op.BITWISE_NOT,
4581 "operands": (1, 0),
4582 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4583 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004584 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4585 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004586 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004587 "ceil": {
4588 "op": Op.CEIL,
4589 "operands": (1, 0),
4590 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4591 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004592 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4593 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004594 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004595 "clz": {
4596 "op": Op.CLZ,
4597 "operands": (1, 0),
4598 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4599 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004600 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4601 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004602 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004603 "exp": {
4604 "op": Op.EXP,
4605 "operands": (1, 0),
4606 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4607 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004608 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4609 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004610 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004611 "floor": {
4612 "op": Op.FLOOR,
4613 "operands": (1, 0),
4614 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4615 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004616 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4617 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004618 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004619 "log": {
4620 "op": Op.LOG,
4621 "operands": (1, 0),
4622 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4623 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004624 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4625 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004626 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004627 "logical_not": {
4628 "op": Op.LOGICAL_NOT,
4629 "operands": (1, 0),
4630 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4631 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004632 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4633 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004634 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004635 "negate": {
4636 "op": Op.NEGATE,
4637 "operands": (1, 0),
4638 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4639 "qgen": TosaQuantGen.qgUnary,
4640 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004641 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4642 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4643 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004644 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004645 "reciprocal": {
4646 "op": Op.RECIPROCAL,
4647 "operands": (1, 0),
4648 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4649 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004650 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4651 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004653 "rsqrt": {
4654 "op": Op.RSQRT,
4655 "operands": (1, 0),
4656 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4657 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004658 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4659 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004661 # Elementwise Ternary operators
4662 "select": {
4663 "op": Op.SELECT,
4664 "operands": (3, 0),
4665 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
4666 "types": TYPE_FIB,
4667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004668 # Comparison operators
4669 "equal": {
4670 "op": Op.EQUAL,
4671 "operands": (2, 0),
4672 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4673 "types": TYPE_FI32,
4674 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004675 "greater_equal": {
4676 "op": Op.GREATER_EQUAL,
4677 "operands": (2, 0),
4678 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4679 "types": TYPE_FI32,
4680 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004681 "greater": {
4682 "op": Op.GREATER,
4683 "operands": (2, 0),
4684 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4685 "types": TYPE_FI32,
4686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004687 # Reduction operators
4688 "reduce_all": {
4689 "op": Op.REDUCE_ALL,
4690 "operands": (1, 0),
4691 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4692 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004693 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4694 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4695 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004696 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004697 "reduce_any": {
4698 "op": Op.REDUCE_ANY,
4699 "operands": (1, 0),
4700 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4701 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004702 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4703 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4704 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004706 "reduce_max": {
4707 "op": Op.REDUCE_MAX,
4708 "operands": (1, 0),
4709 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4710 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004711 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4712 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4713 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004714 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004715 "reduce_min": {
4716 "op": Op.REDUCE_MAX,
4717 "operands": (1, 0),
4718 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4719 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004720 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4721 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4722 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004724 "reduce_product": {
4725 "op": Op.REDUCE_PRODUCT,
4726 "operands": (1, 0),
4727 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4728 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004729 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4730 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4731 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004732 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004733 "reduce_sum": {
4734 "op": Op.REDUCE_SUM,
4735 "operands": (1, 0),
4736 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4737 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004738 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4739 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4740 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004741 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004742 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004743 "concat": {
4744 "op": Op.CONCAT,
4745 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01004746 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004747 "types": TYPE_FIB,
4748 },
4749 "pad": {
4750 "op": Op.PAD,
4751 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004752 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004753 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4754 "qgen": TosaQuantGen.qgPad,
4755 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004756 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
4757 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004758 },
4759 "reshape": {
4760 "op": Op.RESHAPE,
4761 "operands": (1, 0),
4762 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4763 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004764 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4765 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004766 },
4767 "reverse": {
4768 "op": Op.REVERSE,
4769 "operands": (1, 0),
4770 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4771 "types": TYPE_FIB,
4772 },
4773 "slice": {
4774 "op": Op.SLICE,
4775 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004776 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004777 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4778 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004779 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
4780 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
4781 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004782 },
4783 "tile": {
4784 "op": Op.TILE,
4785 "operands": (1, 0),
4786 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4787 "types": TYPE_FIB,
4788 },
4789 "transpose": {
4790 "op": Op.TRANSPOSE,
4791 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004792 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004793 "build_fcn": (
4794 build_transpose,
4795 TosaTensorGen.tgBasic,
4796 TosaArgGen.agTranspose,
4797 ),
4798 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004799 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
4800 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004802 # Data nodes
4803 "const": {
4804 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004805 "operands": (0, 1),
4806 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004807 "types": TYPE_FIB,
4808 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004809 "identity": {
4810 "op": Op.IDENTITY,
4811 "operands": (1, 0),
4812 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4813 "types": TYPE_FIB,
4814 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004815 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004816 "gather": {
4817 "op": Op.GATHER,
4818 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4819 "operands": (1, 0),
4820 "rank": (3, 3),
4821 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4822 "types": TYPE_INT_FP,
4823 },
4824 "scatter": {
4825 "op": Op.SCATTER,
4826 # Only specify 'values_in' tensor here.
4827 #'indices' and 'input' are generated in op building stage
4828 "operands": (2, 0),
4829 "rank": (3, 3),
4830 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4831 "types": TYPE_INT_FP,
4832 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004833 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004834 "resize": {
4835 "op": Op.RESIZE,
4836 "operands": (1, 0),
4837 "rank": (4, 4),
4838 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4839 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004840 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4841 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4842 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004843 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004844 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4845 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004846 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004847 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004848 "cast": {
4849 "op": Op.CAST,
4850 "operands": (1, 0),
4851 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4852 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4853 },
4854 "rescale": {
4855 "op": Op.RESCALE,
4856 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01004857 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004858 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004859 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01004860 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
4861 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4862 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004863 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004864 # Custom
4865 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004866 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004867 # Two varients of cond_if, one that generates one of two constant tensors (no
4868 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4869 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004870 "cond_if_const": {
4871 "op": Op.COND_IF,
4872 "operands": (0, 2),
4873 "build_fcn": (
4874 build_cond_if_const,
4875 TosaTensorGen.tgBasic,
4876 TosaArgGen.agCondIf,
4877 ),
4878 "types": [DType.BOOL],
4879 },
4880 "cond_if_binary": {
4881 "op": Op.COND_IF,
4882 "operands": (2, 0),
4883 "build_fcn": (
4884 build_cond_if_binary,
4885 TosaTensorGen.tgBasic,
4886 TosaArgGen.agCondIf,
4887 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004888 "types": TYPE_INT_FP,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004889 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004890 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004891 "while_loop": {
4892 "op": Op.WHILE_LOOP,
4893 "operands": (0, 1),
4894 "build_fcn": (
4895 build_while_loop,
4896 TosaTensorGen.tgBasic,
4897 TosaArgGen.agWhileLoop,
4898 ),
4899 "types": [DType.INT32],
4900 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004901 }
4902
Kevin Cheng550ccc52021-03-03 11:21:43 -08004903
Eric Kunzee5e26762020-10-13 16:11:07 -07004904class OutputShaper:
4905 # Methods in this class compute the expected output shape and datatype
4906 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; no attributes are set here."""
4909
# These methods compute the output shape and dtype for a class of operators
# and return the result of ser.addOutput() for that output tensor.
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of a broadcasting elementwise binary op.

    Args:
        ser: serializer; the output tensor is created with ser.addOutput.
        rng: random generator; only used to pick a wrong output dtype for
            the ErrorIf.WrongOutputType negative test.
        a, b: input tensors. They must share a dtype and, unless the test
            is ErrorIf.RankMismatch, a rank (both checked with asserts).
        error_name: optional ErrorIf name selecting a negative-test variant,
            or None for a valid test.

    Returns:
        The result of ser.addOutput for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, dim in enumerate(a.shape):
        # Broadcast a's size-1 dims only for valid tests. For any ErrorIf
        # variant keep a's dims: with ErrorIf.RankMismatch, b may not even
        # have dimension i.
        if dim == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(dim)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        # Keep the set-difference form: rng.choice's pick for a given seed
        # depends on the order of this candidate list.
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004933
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor of a binary op that does not broadcast.

    a and b must agree in rank, dtype and every dimension (asserted). The
    output copies a's shape and dtype.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for axis, extent in enumerate(a.shape):
        assert extent == b.shape[axis]
    out_shape = list(a.shape)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004945
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor of an elementwise unary op.

    The output always has a's shape. Its dtype is a's dtype, except for the
    ErrorIf.WrongOutputType negative test. In that case rng picks a dtype
    other than a's from INT8/INT16/INT32/INT48/FLOAT.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidate_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
    mismatched = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(a.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004956
@staticmethod
def selectOp(ser, cond, a, b):
    """Create the output tensor of SELECT(cond, a, b).

    cond, a and b must share a rank, and a and b a dtype (asserted). Each
    output dimension is the largest of the three operands' sizes on that
    axis. The output dtype is a's.
    """
    assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(cond.shape[axis], a.shape[axis], b.shape[axis])
        for axis in range(len(a.shape))
    ]
    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004967
@staticmethod
def binaryComparisonOp(ser, a, b):
    """Create the boolean output tensor of a broadcasting comparison op.

    a and b must share a rank and dtype (asserted). Each output dimension
    is b's size wherever a's size is 1, and a's size otherwise. The output
    dtype is always DType.BOOL.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[axis] if a.shape[axis] == 1 else a.shape[axis]
        for axis in range(len(a.shape))
    ]
    return ser.addOutput(out_shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004983
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a reduction along ``axis``.

    For a valid test the output is a's shape with ``axis`` set to 1, with
    a's dtype. Negative-test variants selected by error_name:
      * ErrorIf.AxisSmallerZero / ErrorIf.AxisLargerRank: the shape is left
        exactly as a's shape.
      * ErrorIf.ShapeOfAxisNotOne: ``axis`` keeps a's size. If that size is
        already 1, it is replaced by rng.integers(2, 10).
      * ErrorIf.WrongOutputType: the dtype is chosen by rng from
        INT8/INT16/INT32/INT48/FLOAT, excluding a's dtype.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(dtype_pool) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005000
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of ARGMAX along ``axis``.

    For a valid test the output is a's shape with ``axis`` removed and dtype
    DType.INT32. Negative-test variants selected by error_name:
      * ErrorIf.AxisSmallerZero / ErrorIf.AxisLargerRank: ``axis`` is not
        removed.
      * ErrorIf.ArgmaxOutputRankMismatch: rng either drops the leading
        dimension (only when more than one remains) or appends a 1.
      * ErrorIf.ArgmaxOutputShapeMismatch: each remaining dimension grows by
        rng.integers(1, 10).
      * ErrorIf.WrongOutputType: the dtype is chosen by rng from
        INT8/INT16/INT48/FLOAT (anything but INT32).
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # rng.choice is always drawn, even when only one dim remains.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # NOTE(review): for a rank-1 input out_shape is empty here, so no
        # mismatch is produced -- confirm rank-1 is excluded for this case.
        out_shape = [extent + rng.integers(1, 10) for extent in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(dtype_pool) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005026
@staticmethod
def conv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Create the output tensor of a 2D convolution.

    Layouts: ifm is NHWC, filter is OHWI and the output is NHWC.

    Args:
        ser: serializer the output is added to via ser.addOutput.
        ifm: input tensor; its dtype selects the output dtype.
        filter: weight tensor; shape[1]/shape[2] are the kernel height and
            width, and shape[0] is the number of output channels.
        strides: [stride_h, stride_w].
        padding: [top, bottom, left, right], or the two-value [pad_h, pad_w]
            form used for transpose_conv2d, applied to both sides.
        dilations: [dilation_h, dilation_w].

    Returns:
        The result of ser.addOutput for the computed shape and dtype
        (INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT).

    Raises:
        Exception: if ifm.dtype is not INT8, INT16 or FLOAT.
    """

    def spatial_extent(in_size, k_size, dilation, pad_before, pad_after, stride):
        # A dilated kernel covers k + (k - 1) * (dilation - 1) input elements.
        return (
            in_size - k_size - (k_size - 1) * (dilation - 1) + pad_before + pad_after
        ) // stride + 1

    if len(padding) == 2:
        # transpose_conv2d passes one pad per axis (H, W); duplicate each
        # value to obtain the T, B, L, R layout used below.
        pad_h, pad_w = padding
        padding = [pad_h, pad_h, pad_w, pad_w]

    out_h = spatial_extent(
        ifm.shape[1], filter.shape[1], dilations[0], padding[0], padding[1], strides[0]
    )
    out_w = spatial_extent(
        ifm.shape[2], filter.shape[2], dilations[1], padding[2], padding[3], strides[1]
    )
    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    for in_dtype, result_dtype in (
        (DType.INT8, DType.INT32),
        (DType.INT16, DType.INT48),
        (DType.FLOAT, DType.FLOAT),
    ):
        if ifm.dtype == in_dtype:
            return ser.addOutput(ofm_shape, result_dtype)

    raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005067
@staticmethod
def conv3dOp(ser, ifm, filter, strides, padding, dilations):
    """Create the output tensor of a 3D convolution.

    Layouts: ifm is NDHWC, filter is ODHWI and the output is NDHWC.

    Args:
        ser: serializer the output is added to via ser.addOutput.
        ifm: input tensor; its dtype selects the output dtype.
        filter: weight tensor; shape[1:4] are the kernel D/H/W sizes and
            shape[0] is the number of output channels.
        strides: [stride_d, stride_h, stride_w].
        padding: [d_before, d_after, h_before, h_after, w_before, w_after].
        dilations: [dilation_d, dilation_h, dilation_w].

    Returns:
        The result of ser.addOutput for the computed shape and dtype
        (INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT).

    Raises:
        Exception: if ifm.dtype is not INT8, INT16 or FLOAT.
    """
    # Spatial axis k (0=D, 1=H, 2=W) lives at index k + 1 of both ifm and
    # filter, and uses the padding pair at indices 2k and 2k + 1.
    spatial = [
        (
            ifm.shape[k + 1]
            - filter.shape[k + 1]
            - (filter.shape[k + 1] - 1) * (dilations[k] - 1)
            + padding[2 * k]
            + padding[2 * k + 1]
        )
        // strides[k]
        + 1
        for k in range(3)
    ]
    ofm_shape = [ifm.shape[0]] + spatial + [filter.shape[0]]

    result_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in result_dtypes:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
    return ser.addOutput(ofm_shape, result_dtypes[ifm.dtype])
5111
@staticmethod
def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Add the output tensor of a depthwise 2-D convolution to ``ser``.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    For spatial axis i (0 = H, 1 = W) with kernel size k = filter.shape[i],
    the output extent is
        (in + padding[2*i] + padding[2*i + 1] - (k + (k - 1) * (dilations[i] - 1)))
        // strides[i] + 1
    and the output channel count is ``filter.shape[2] * filter.shape[3]``.

    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input.

    Raises:
        Exception: for any other input dtype.
    """
    # Input dtype -> output dtype
    out_dtype_for = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }

    out_hw = []
    for axis in range(2):
        k = filter.shape[axis]
        # Size of the kernel footprint once dilation gaps are inserted
        dilated_k = k + (k - 1) * (dilations[axis] - 1)
        padded = ifm.shape[axis + 1] + padding[2 * axis] + padding[2 * axis + 1]
        out_hw.append((padded - dilated_k) // strides[axis] + 1)

    ofm_shape = [
        ifm.shape[0],
        out_hw[0],
        out_hw[1],
        filter.shape[2] * filter.shape[3],
    ]

    if ifm.dtype not in out_dtype_for:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
    return ser.addOutput(ofm_shape, out_dtype_for[ifm.dtype])
Eric Kunzee5e26762020-10-13 16:11:07 -07005145
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2-D pooling operator to ``ser``.

    The input is NHWC. With valid strides (> 0) and non-negative padding,
    each spatial output size is
        (in + pad_a + pad_b + stride - kernel) // stride
    using pad[0:2] for H and pad[2:4] for W. Otherwise both are set to 1,
    because such a test is invalid anyway.

    error_name adjusts the result for negative tests:
      - ErrorIf.PoolingOutputShapeMismatch: H and W are each increased by a
        value drawn from [1, 5] with ``rng``.
      - ErrorIf.WrongOutputType: the dtype is drawn with ``rng`` from the
        listed dtypes other than ``ifm.dtype``.
    Otherwise the output keeps ``ifm.dtype`` and the channel count.
    """
    # input: NHWC
    strides_ok = stride[0] > 0 and stride[1] > 0
    if strides_ok and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        out_w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
    else:
        # Invalid stride/padding: the test is invalid anyway, use 1x1
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Offset H then W (in that rng order) to force a shape mismatch
        out_h = out_h + rng.choice([1, 2, 3, 4, 5])
        out_w = out_w + rng.choice([1, 2, 3, 4, 5])

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
    # Keep the set-difference form so rng.choice sees the same ordering
    wrong_dtypes = list(set(candidates) - set([ifm.dtype]))
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005172
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Add the output tensor of a FULLY_CONNECTED operator to ``ser``.

    Shapes: input is [N, IC], filter is [OC, IC], and the output is [N, OC].

    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input. For ErrorIf.WrongOutputType, a dtype other than
    that correct output dtype is drawn with ``rng.choice``. For
    ErrorIf.WrongInputType with an unsupported input dtype, INT32 is used.

    Raises:
        Exception: if the input dtype is not INT8, INT16 or FLOAT and the
            case is not covered by the error tests above.
    """
    # input: N, IC
    # filter: OC, IC
    # output: N, OC
    output_shape = [input.shape[0], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        if input.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT
            )
        elif input.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT
            )
        elif input.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48
            )
        else:
            # No wrong-type list exists for this input dtype. Fail clearly
            # instead of hitting an unbound `incorrect_types` below.
            raise Exception("Unsupported input dtype: {}".format(input.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005202
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a MATMUL operator to ``ser``.

    Shapes: a is [N, H, C], b is [N, C, W], and the output is [N, H, W].

    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input (based on ``a.dtype``). For
    ErrorIf.WrongOutputType, a dtype other than that correct output dtype
    is drawn with ``rng.choice``. For ErrorIf.WrongInputType with an
    unsupported input dtype, INT32 is used.

    Raises:
        Exception: if ``a.dtype`` is not INT8, INT16 or FLOAT and the case
            is not covered by the error tests above.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT
            )
        elif a.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48
            )
        else:
            # No wrong-type list exists for this input dtype. Fail clearly
            # instead of hitting an unbound `incorrect_types` below.
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005232
@staticmethod
def concatOp(ser, axis, *a):
    """Add the output tensor of a CONCAT along ``axis`` to ``ser``.

    The output copies the first input's shape, except that dimension
    ``axis`` becomes the sum of that dimension over all inputs. The output
    dtype is the first input's dtype. The input shapes are not modified.
    """
    first = a[0]
    concat_shape = first.shape.copy()
    concat_shape[axis] = sum(tensor.shape[axis] for tensor in a)
    return ser.addOutput(concat_shape, first.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005246
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of a PAD operator to ``ser``.

    Dimension i of the output is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    error_name adjusts the result for negative tests:
      - ErrorIf.PadSmallerZero: any dimension below 1 is clamped to 1, so
        the (invalid) test still gets a positive shape.
      - ErrorIf.WrongOutputType: the dtype is drawn with ``rng`` from the
        listed dtypes other than ``a.dtype``.
    Otherwise the output dtype is ``a.dtype``.
    """
    padded_shape = [
        padding[idx][0] + padding[idx][1] + dim for idx, dim in enumerate(a.shape)
    ]

    # Fix negative output shape if error_if test causes it
    if error_name == ErrorIf.PadSmallerZero and min(padded_shape) < 1:
        padded_shape = [max(dim, 1) for dim in padded_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(padded_shape, a.dtype)

    candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
    # Keep the set-difference form so rng.choice sees the same ordering
    wrong_dtypes = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(padded_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005267
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of a RESHAPE operator to ``ser``.

    Every -1 entry in ``shape`` is replaced by
    (element count of ``a``) // (product of the non -1 entries of ``shape``).

    error_name adjusts the result for negative tests:
      - ErrorIf.TensorSizeInputOutputMismatch: each output dimension is
        increased by ``rng.integers(1, 10)``, drawn in dimension order.
      - ErrorIf.WrongOutputType: the dtype is drawn with ``rng`` from the
        listed dtypes other than ``a.dtype``.
    Otherwise the output dtype is ``a.dtype``.
    """
    requested = shape.copy()

    input_elements = 1
    for dim in a.shape:
        input_elements *= dim

    # Product of the explicitly given dimensions (-1 entries are inferred)
    known_elements = 1
    for dim in requested:
        if dim != -1:
            known_elements *= dim

    new_shape = [
        input_elements // known_elements if dim == -1 else dim for dim in requested
    ]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        new_shape = [dim + rng.integers(1, 10) for dim in new_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
    # Keep the set-difference form so rng.choice sees the same ordering
    wrong_dtypes = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(new_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005299
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add the output tensor of a SLICE operator to ``ser``.

    The output shape is a copy of ``size``; ``start`` does not affect it.

    error_name adjusts the result for negative tests:
      - ErrorIf.WrongOutputType: the dtype is drawn with ``rng`` from the
        listed dtypes other than ``a.dtype``.
      - ErrorIf.SizeOutputShapeMismatch: each dimension is shifted by a
        value drawn with ``rng``. Dimensions <= 2 get +1 or +2; others get
        one of -2, -1, +1, +2.
    Otherwise the output dtype is ``a.dtype``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        # Keep the set-difference form so rng.choice sees the same ordering
        wrong_dtypes = list(set(candidates) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = a.dtype

    sliced_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(sliced_shape):
            # Small dims only grow, so the shape never drops to zero or below
            offsets = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            sliced_shape[idx] = dim + rng.choice(offsets)

    return ser.addOutput(sliced_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005321
@staticmethod
def tileOp(ser, a, multiples):
    """Add the output tensor of a TILE operator to ``ser``.

    Dimension i of the output is ``a.shape[i] * multiples[i]``. The dtype
    is ``a.dtype``. ``multiples`` must have one entry per input dimension.
    """
    assert len(multiples) == len(a.shape)

    tiled_shape = [dim * times for dim, times in zip(a.shape, multiples)]

    return ser.addOutput(tiled_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005332
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor of a TRANSPOSE operator to ``ser``.

    Output dimension i is ``a.shape[perms[i]]``. ``perms`` must have one
    entry per input dimension.

    error_name adjusts the result for negative tests:
      - ErrorIf.IndexOutsideBounds: ``perms`` is not used (it may hold
        out-of-range indices), and every dimension is set to ``a.shape[0]``.
      - ErrorIf.WrongOutputType: the dtype is drawn with ``rng`` from the
        listed dtypes other than ``a.dtype``.
    Otherwise the output dtype is ``a.dtype``.
    """
    rank = len(a.shape.copy())
    assert len(perms) == rank

    if error_name == ErrorIf.IndexOutsideBounds:
        permuted_shape = [a.shape[0] for _ in range(rank)]
    else:
        permuted_shape = [a.shape[perms[axis]] for axis in range(rank)]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(permuted_shape, a.dtype)

    candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
    # Keep the set-difference form so rng.choice sees the same ordering
    wrong_dtypes = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(permuted_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005354
@staticmethod
def gatherOp(ser, values, indices):
    """Add the output tensor of a GATHER operator to ``ser``.

    ``values`` must be rank 3 and ``indices`` rank 2, with matching
    dimension 0. The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]], and the dtype is
    ``values.dtype``.
    """
    assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    gathered = indices.shape[1]
    channels = values.shape[2]

    return ser.addOutput([batch, gathered, channels], values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005364
@staticmethod
def scatterOp(ser, values_in, indices, input):
    """Add the output tensor of a SCATTER operator to ``ser``.

    The assertions require ``values_in`` of rank 3 and ``indices`` of
    rank 2 with the same dimension 0 (N). ``input`` must be rank 3, with
    dimension 1 equal to ``indices.shape[1]`` (W) and dimension 2 equal to
    ``values_in.shape[2]`` (C).

    The output has the shape and dtype of ``values_in``. The shape is passed
    as a new list, so the output tensor does not share the ``values_in``
    shape object.
    """
    assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy: passing values_in.shape directly would alias the input's list
    output_shape = list(values_in.shape)

    return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005377
@staticmethod
def tableOp(ser, input, table_dtype):
    """Add the output tensor of a TABLE operator to ``ser``.

    The output has the same shape as ``input``, passed as a new list so the
    input's shape object is not shared. The dtype depends on the table
    dtype: INT16 tables give INT32 output, and INT8 tables give INT8 output.
    Any other ``table_dtype`` fails the assertion.
    """
    # Same shape as the input, but dtype dependent on table dtype
    assert table_dtype == DType.INT16 or table_dtype == DType.INT8
    output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
    # Copy: passing input.shape directly would alias the input's list
    return ser.addOutput(list(input.shape), output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005384
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor of a RESIZE operator to ``serializer``.

    The output is [N, output_dims[0], output_dims[1], C], with N and C taken
    from ``input.shape``, and its dtype is ``output_dtype``. ``mode``,
    ``stride``, ``offset``, ``shift``, ``stride_fp``, ``offset_fp`` and
    ``input_dtype`` are not read here.

    error_name adjusts the result for negative tests:
      - ErrorIf.WrongRank: see the review note below.
      - ErrorIf.BatchMismatch: N is increased by ``rng.integers(1, 10)``.
      - ErrorIf.ChannelMismatch: C is increased by ``rng.integers(1, 10)``.
    """
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): input.shape[3] is not read in this case. Also,
        # output_dims[0] appears twice and input.shape[0] is used in the
        # channel position; confirm this is intended.
        resized = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        resized = [
            input.shape[0] + rng.integers(1, 10),
            output_dims[0],
            output_dims[1],
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        resized = [
            input.shape[0],
            output_dims[0],
            output_dims[1],
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        resized = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

    return serializer.addOutput(resized, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005412
@staticmethod
def typeConversionOp(ser, val, out_dtype):
    """Add the output tensor of a type conversion to ``ser``.

    The output has the shape of ``val``, passed as a new list so ``val``'s
    shape object is not shared, and the dtype ``out_dtype``.
    """
    # Copy: passing val.shape directly would alias the input's list
    return ser.addOutput(list(val.shape), out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005416
@staticmethod
def transposeConv2DOp(ser, ifm, output_shape):
    """Add the output tensor of a transposed 2-D convolution to ``ser``.

    The shape is used exactly as given in ``output_shape``. Only the dtype
    is derived from the input: INT32 for INT8, INT48 for INT16, and FLOAT
    for FLOAT.

    Raises:
        Exception: for any other input dtype.
    """
    # Input dtype -> output dtype
    out_dtype_for = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in out_dtype_for:
        return ser.addOutput(output_shape, out_dtype_for[ifm.dtype])
    raise Exception("Unsupported input dtype: {}".format(ifm.dtype))