blob: 1f35b8b092b7451f98708c2f9ab0bc48dca268cf [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
# so the flatbuffers-based TOSA serializer can be imported below.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *  # noqa: F401,F403 - serializer names are used throughout this file
import tosa
from tosa_error_if import ErrorIf

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition."""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point appropriate for dtype.

        INT8 gets a full-range signed zero point and UINT8 a full-range
        unsigned one.  For other dtypes the zero point must be 0 unless a
        *ZeroPointNotZero ERROR_IF test is being generated, in which case a
        guaranteed non-zero value is returned.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            # Force a non-zero zero point so the ERROR_IF condition triggers
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo (input_zp, output_zp), optionally forcing the
        relevant zero point non-zero for ERROR_IF tests.

        Note: getQinfo is always called twice, in input-then-output order,
        so the RNG consumption is identical in every branch.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo (input_zp, weights_zp) for convolution-style ops.

        dtype_or_dtypeList is either a single dtype or a list of
        [input, weights, accumulator] dtypes.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo (a_zp, b_zp); both inputs get a forced
        non-zero zero point for the InputZeroPointNotZero ERROR_IF test."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo (input_zp), optionally forced non-zero for the
        InputZeroPointNotZero ERROR_IF test."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift.
        # scale32 selects a 32-bit (scaleBits=31) or 16-bit (scaleBits=15)
        # multiplier; the returned shift is always within [2, 62].

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        # frexp returns mantissa in [0.5, 1) (negative for negative input)
        # and the binary exponent
        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # round() may push the mantissa up to exactly 1.0 - renormalize
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
175
176
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """One identical random shape per operand; for RankMismatch ERROR_IF
        tests, operands other than index 1 get a shape of a different rank."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                # Re-generate the shape for the next operand with a wrong rank
                # (rank 1 can only go up; otherwise go up or down by 1)
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """One identical random NHWC shape per operand (rank 4 unless a
        WrongRank ERROR_IF test is being built)."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Shapes for gather/scatter style ops: [values_in_shape, input_shape],
        where input shares N and C with values_in and gets a random W."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """One shape per operand, with a randomly chosen operand getting a
        random dimension set to 1 to exercise broadcasting."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Shapes for CONV2D: [NHWC ifm, OHWI filter, OC bias]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Shapes for CONV3D: [NDHWC ifm, ODHWI filter, OC bias]."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Shapes for TRANSPOSE_CONV2D: [NHWC ifm, OHWI filter, OC bias]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Shapes for DEPTHWISE_CONV2D: [NHWC ifm, HWCM filter, M*C bias]."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Shapes for FULLY_CONNECTED: [input, filter (OC x IC), OC bias]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        # NOTE(review): the result is bound to `shape`, which is never used below;
        # this only restricts input_shape if eiRestrictDimension mutates its
        # argument in place -- TODO confirm.
        shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Shapes for MATMUL: [a (N,H,C), b (N,C,OC)] with a random OC."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        # NOTE(review): result bound to unused `shape` -- same in-place-mutation
        # caveat as tgFullyConnected; TODO confirm.
        shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Shapes for CONCAT: the operand count plus 0-3 extra identical shapes."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rewrite a concat shape list so the first shape is split along
        `axis` across the remaining inputs, avoiding many large tensors."""
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
490
491
Eric Kunzee5e26762020-10-13 16:11:07 -0700492class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800493 """Argument generators create exhaustive or random lists of attributes for operators that take
494 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
495 tuples where the descriptive_name is appended to the test name and the arglist is expanded
496 as arguments to the operator build function."""
497
Eric Kunzee5e26762020-10-13 16:11:07 -0700498 def __init__(self):
499 pass
500
501 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100502 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800503 """A trivial argument generator for operators that don't take any
504 non-tensor arguments"""
505 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
507 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100508 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800509 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700511 shape = shapeList[0]
512
Matthew Haddond6ce7252021-09-29 15:35:44 +0100513 if error_name == ErrorIf.AxisSmallerZero:
514 small_axis = testGen.rng.integers(-5, 0)
515 axes.append(("axis{}".format(small_axis), [small_axis]))
516 elif error_name == ErrorIf.AxisLargerRank:
517 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
518 axes.append(("axis{}".format(large_axis), [large_axis]))
519 else:
520 for a in range(0, len(shape)):
521 axes.append(("axis{}".format(a), [a]))
522
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 return axes
524
525 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100526 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700527 arg_list = []
528
529 ifm_shape = shapeList[0]
530 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100531 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
532 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700533
Les Bell7aa69f42021-09-20 10:44:07 +0100534 # Check the rank
535 rank = 5 if opName.startswith("conv3d") else 4
536 assert len(ifm_shape) == rank
537 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700538
Les Bell7aa69f42021-09-20 10:44:07 +0100539 # kernel rank omits batch and channels
540 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
Les Bell7aa69f42021-09-20 10:44:07 +0100542 # Generate comprehensive argument lists
543 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
544 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
545 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
546 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
547 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
548 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700549
Les Bell7aa69f42021-09-20 10:44:07 +0100550 # add some oversize argument values
551 if max(ifm_shape) < 64:
552 bigPadding = 9
553 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
554 bigStride = 8
555 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
556 bigDilation = 7
557 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100558
559 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100560 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
561 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
562 if sparsity < 13:
563 sparsity = 1
564 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
565 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100566 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100567 for s in sorted(list(strides)):
568 for p in sorted(list(paddings)):
569 for d in sorted(list(dilations)):
570 if (n % sparsity == 0
571 # padding must not exceed the kernel size ?
572 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
573 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
574 # the padded shape must exceed the kernel size
575 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
576 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
577 # the padded shape must exceed the dilation
578 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
579 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
580 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100581 arg_list.append(
582 (
583 "st{}_pad{}_dilat{}".format(
584 "".join([str(x) for x in s]),
585 "".join([str(x) for x in p]),
586 "".join([str(x) for x in d]),
587 ),
588 [s, p, d],
589 )
590 )
591 n += 1
592
Kevin Cheng1533b852021-09-01 12:51:58 -0700593 return arg_list
594
595 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100596 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 arg_list = []
598
599 ifm_shape = shapeList[0]
600 filter_shape = shapeList[1]
601
602 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800603 assert len(ifm_shape) == 4
604 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Les Bell7aa69f42021-09-20 10:44:07 +0100606 # Generate comprehensive argument lists
607 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
608 paddings = {x for x in itertools.product(*([p_vals] * 2))}
609 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
610 strides = {x for x in itertools.product(*([s_vals] * 2))}
611 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
612 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
Les Bell7aa69f42021-09-20 10:44:07 +0100614 # add some oversize argument values
615 if max(ifm_shape) < 64:
616 bigPadding = 9
617 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
618 bigStride = 8
619 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
620 bigDilation = 7
621 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Les Bell7aa69f42021-09-20 10:44:07 +0100623 # There are too many parameter combinations, so generate them sparsely
624 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
625 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
626 if sparsity < 13:
627 sparsity = 1
628 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
629 sparsity += 1
630 n = 0
631 for s in sorted(list(strides)):
632 for p in sorted(list(paddings)):
633 for d in sorted(list(dilations)):
634 if n % sparsity == 0:
635 # Determine the output shape
636 oh = (
637 ifm_shape[1]
638 - filter_shape[1]
639 - (filter_shape[1] - 1) * (d[0] - 1)
640 + 2 * p[0]
641 ) // s[0] + 1
642 ow = (
643 ifm_shape[2]
644 - filter_shape[2]
645 - (filter_shape[2] - 1) * (d[1] - 1)
646 + 2 * p[1]
647 ) // s[1] + 1
648 os = [ifm_shape[0], oh, ow, filter_shape[0]]
649 arg_list.append(
650 (
651 "st{}_pad{}_dilat{}_os{}".format(
652 "".join([str(x) for x in s]),
653 "".join([str(x) for x in p]),
654 "".join([str(x) for x in d]),
655 "x".join([str(x) for x in os]),
656 ),
657 [s, p, d, os],
658 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 )
Les Bell7aa69f42021-09-20 10:44:07 +0100660 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
662 return arg_list
663
664 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100665 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700666 arg_list = []
667 rank = len(shapeList[0])
668
Les Bell7ffccce2021-07-28 15:37:02 +0100669 # Exhaustively test combinations of padding on each side of each dimension
670 # - the range of padding values is defined by pad_min and pad_max
671 # - for padding >9, the name format needs to be more distinctive
672 pad_min, pad_max = 0, 1
673 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100674 if error_name == ErrorIf.PadSmallerZero:
675 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100676 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
677 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700678
Les Bell7ffccce2021-07-28 15:37:02 +0100679 for paddings in shape_pad_values:
680 name = "pad"
681 for r in range(rank):
682 before, after = paddings[r]
683 name = f"{name}{before}{after}"
684 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700685
686 return arg_list
687
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        # Enumerate (stride, pad, kernel) combinations for pooling operators,
        # generated sparsely to keep the overall test count reasonable.
        # Returns a list of (test_name_suffix, [stride, pad, kernel]).
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    # Calculate output height to test for error_if conditions
                    oh = (shape[1] + p[0] + p[1] + s[0] - k[0]) // s[0]
                    ow = (shape[2] + p[2] + p[3] + s[1] - k[1]) // s[1]
                    y = (oh * s[0]) - p[0] - p[1] - s[0] + k[0]
                    x = (ow * s[1]) - p[2] - p[3] - s[1] + k[1]

                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        # NOTE(review): eiPoolingErrorIf is called for every
                        # combination, before the sparsity check - reordering
                        # these would change RNG consumption.
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                        # the reconstructed extent must stay inside the input
                        and y < shape[1] and x < shape[2]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
759
760 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100761 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700762 arg_list = []
763
764 # Enumerate the output types here
765 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800766 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700767 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800768 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700769 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800770 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700771 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800772 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700773 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800774 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700775 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800776 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700777
778 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800779 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700780
781 return arg_list
782
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate (name, [out_dtype, scale32, double_round, per_channel])
        argument tuples for RESCALE tests.

        Combinations that are illegal per the TOSA spec are skipped, except
        when ``error_name`` deliberately requires the illegal combination in
        order to exercise the corresponding ERROR_IF check.
        """
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            # 8-bit outputs carry a zero point; skip them when the test needs
            # a non-zero output zero point to be impossible to express.
            if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
                continue
            if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue
            # For WrongOutputType tests keep only the genuinely-wrong pairs.
            if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
                continue

            for scale32 in [False, True]:
                # ScaleTrue / ScaleNotTrue tests pin scale32 to one value.
                if error_name == ErrorIf.ScaleTrue and scale32 == False:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and double_round == False:
                        continue
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
                            # Illegal condition. Must be scale32=False
                            continue
                        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
830
Kevin Chengaee1fac2020-11-11 13:54:06 -0800831 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100832 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800833 arg_list = []
834
835 if dtype is DType.INT32:
836 for p in range(testGen.args.num_rand_permutations):
837
838 shift = testGen.randInt(0, 32)
839
Kevin Cheng550ccc52021-03-03 11:21:43 -0800840 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800841 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100842 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800843
844 return arg_list
845
846 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100847 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800848 arg_list = []
849
Kevin Cheng550ccc52021-03-03 11:21:43 -0800850 arg_list.append(("roundTrue", [True]))
851 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800852
853 return arg_list
854
Eric Kunzee5e26762020-10-13 16:11:07 -0700855 # Helper function for reshape. Gets some factors of a larger number.
856 @staticmethod
857 def getFactors(val, start=1):
858 factors = []
859
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100860 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700861 if (val % i) == 0:
862 factors.append(i)
863
864 return factors
865
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (name, [newShape]) argument tuples for RESHAPE.

        Each candidate shape preserves the input's element count and is built
        from a random factorization; occasionally one dimension is replaced
        with -1 (inferred size).  Duplicate shapes are rejected, with an
        escape counter to bound the retry loop.
        """
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # NOTE(review): presumably randInt excludes the high bound,
            # giving ranks 1-6 - confirm against randInt's definition.
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            # found=True only to enter the loop; real value is set below.
            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # The final dimension takes whatever element count remains.
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
921
Eric Kunzee5e26762020-10-13 16:11:07 -0700922 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100923 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700924 arg_list = []
925
926 ifm_shape = shapeList[0]
927
Matthew Haddone807aae2021-10-11 18:12:58 +0100928
929 if error_name == ErrorIf.IndexOutsideBounds:
930 incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
931 incorrect_small_index = range(-len(ifm_shape), 0)
932 permutations = [p for p in itertools.permutations(incorrect_large_index)]
933 permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
934 elif error_name == ErrorIf.IndexUsedTwice:
935 # Create list with a duplicated index
936 perm_range = list(range(len(ifm_shape)))
937 index_choice = testGen.rng.choice(range(len(perm_range)))
938 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
939 permutations = [p for p in itertools.permutations(perm_range)]
940
941
942 else:
943 # Get all permutations
944 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700945
Jeremy Johnsona6185572021-06-21 15:55:35 +0100946 # Limit to possible permutations from shape dimension or argument setting
947 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700948
Jeremy Johnsona6185572021-06-21 15:55:35 +0100949 # Get random permutation generator that uses all permutations
950 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700951
Jeremy Johnsona6185572021-06-21 15:55:35 +0100952 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700953 arg_list = [
954 ("perm{}".format(p), [random_permutations[p].tolist()])
955 for p in range(limit)
956 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700957 return arg_list
958
959 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100960 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700961 arg_list = []
962
963 ifm_shape = shapeList[0]
964 rank = len(ifm_shape)
965
966 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +0100967 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700968 size = []
969
Kevin Cheng550ccc52021-03-03 11:21:43 -0800970 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700971
972 for i in range(rank):
973 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +0100974 start.append(testGen.randInt(0, ifm_shape[i]))
975 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700976
977 # Invalid slice size?
978 if size[i] == 0:
979 valid = False
980 else:
Matthew Haddone807aae2021-10-11 18:12:58 +0100981 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -0700982 size.append(1)
983
984 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +0100985 # If ERROR_IF test required then incorrect start, size will be returned
986 start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
987 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700988 return arg_list
989
990 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100991 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700992 arg_list = []
993
994 ifm_shape = shapeList[0]
995 rank = len(ifm_shape)
996
997 for p in range(testGen.args.num_rand_permutations):
998
999 # Pick a few random, but small multiple values
1000 # because otherwise this has a tendency to generate
1001 # enormous tensors
1002 multiples = []
1003 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001004 if ifm_shape[i] > 1000:
1005 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
1006 multiples.append(1)
1007 elif max(ifm_shape) > 1000:
1008 multiples.append(2)
1009 else:
1010 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001011 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001012
1013 return arg_list
1014
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument tuples for RESIZE tests.

        Each tuple is (name, [mode, stride, offset, shift, stride_fp,
        offset_fp, output_dims, input_dtype, output_dtype]).  FLOAT output
        uses the *_fp parameters with shift 0; integer outputs use
        fixed-point stride/offset scaled by (1 << shift).  For ERROR_IF
        tests the parameters are corrupted by eiResizeErrorIf.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Grow the output until the downscale ratio is below the
                    # spec's stride limit of 16x.
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Centre-aligned resampling: offset is chosen so that the
                    # centres of the input and output grids coincide.
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # Float path: shift and integer stride/offset unused.
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        # Integer path: quantize stride/offset to fixed point
                        # with the largest shift that keeps them in range.
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            # Out of range: reduce precision and requantize.
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        # Float stride/offset unused on the integer path.
                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list
1185
Matthew Haddon1c00b712021-10-01 15:51:03 +01001186 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001187 # CondIf generates the condition values here.
1188 # Convert to tensors in the build function, along with the
1189 # then and else blocks
1190 arg_list = []
1191
1192 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001193 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001194
1195 return arg_list
1196
Matthew Haddon1c00b712021-10-01 15:51:03 +01001197 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001198 # While loop: 0 iterations, 1, more than 1
1199 arg_list = []
1200
1201 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001202 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001203
1204 return arg_list
1205
class TosaErrorIfArgGen:
    """Helpers that corrupt otherwise-valid operator arguments so a specific
    ERROR_IF condition in the reference model is triggered.

    All helpers are stateless staticmethods; randomness comes from the
    supplied testGen's rng.
    """

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Mutate RESIZE arguments to provoke ``error_name``.

        Returns (shift, stride, stride_fp, offset, offset_fp, outputDType);
        values not relevant to the requested error pass through unchanged.
        """
        if outputDType == DType.FLOAT:
            # Float resize uses stride_fp/offset_fp and requires shift == 0.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            # Integer resize: stride/offset are fixed point, scaled by 1 << shift.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Pick any output dtype that is NOT legal for this {mode, dtype}.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return corrupted (stride, pad, kernel) for pooling ERROR_IF tests,
        or (None, None, None) when the error cannot be produced here."""
        if (error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when output_dtype is illegal for RESCALE of input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        if input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Randomly grow or shrink the input/output tensor name lists to
        trigger the WrongInputList / WrongOutputList checks.

        NOTE(review): error_name is compared against string literals here,
        while the sibling helpers compare against ErrorIf members - confirm
        that callers really pass strings for these two checks.
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimension(shape, error_name):
        """Clamp one dimension of large-rank shapes for WrongRank tests.

        Mutates ``shape`` in place (and also returns it) so oversized test
        tensors stay reasonably small.
        """
        if error_name == ErrorIf.WrongRank:
            if len(shape) > 4:
                shape[4] = 1

        return shape

    # @staticmethod added here: this was the only helper in the class without
    # the decorator; all callers invoke it via the class attribute, so the
    # behavior is unchanged.
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return corrupted (start, size) for SLICE ERROR_IF tests; values
        pass through unchanged for unrelated error names."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i]-1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size
1369
Matthew Haddone86fd342021-09-07 16:12:21 +01001370class TosaErrorValidator:
1371
Matthew Haddon848efb42021-09-09 12:30:53 +01001372 @staticmethod
1373 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1374 # Check ERROR_IF statements
1375
1376 for val_fcn in validator_fcns:
1377 val_result = val_fcn(True, **kwargs)
1378
1379 validator_name = val_result['error_name']
1380 error_result = val_result['error_result']
1381 error_reason = val_result['error_reason']
1382
1383 if error_result:
1384 if error_name == validator_name:
1385 serializer.setExpectedReturnCode(2, error_reason)
1386 else:
1387 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1388 return None # Return None to delete test if wrong ERROR_IF is hit
1389 else:
1390 if error_name == validator_name:
1391 print(f"No ERROR_IF hit for {error_name}")
1392 return None
1393
1394 @staticmethod
1395 def evWrongInputType(check=False, **kwargs):
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001396 all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
Matthew Haddon848efb42021-09-09 12:30:53 +01001397
1398 # Find the unsupported input data types
1399 assert 'op' in kwargs
1400 op = kwargs['op']
1401 input_dtypes = op['types']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001402
1403 allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
1404 wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
Matthew Haddon848efb42021-09-09 12:30:53 +01001405
1406 error_name = ErrorIf.WrongInputType
1407 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1408 error_result = False
1409 error_reason = "Input data type not supported for this operator"
1410
1411 if check:
1412 input_dtype = kwargs['input_dtype']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001413 if op['op'] == Op.FULLY_CONNECTED:
1414 if input_dtype not in allowed_input_dtypes:
1415 error_result = True
1416 elif input_dtype not in input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001417 error_result = True
1418
1419 info_dict = {
1420 "error_name": error_name,
1421 "error_result": error_result,
1422 "error_reason": error_reason,
1423 "param_reqs": param_reqs
1424 }
1425 return info_dict
1426
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for ERROR_IF: output dtype illegal for the op's input dtype.

        With check=False only the metadata dict is produced (used by the test
        generator to pick parameters); with check=True the supplied
        input/output dtype combination is validated per operator.
        """
        error_name = ErrorIf.WrongOutputType
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output data type not supported for this configuration of operator"

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            op = kwargs['op']

            if op['op'] == Op.RESIZE:
                mode = kwargs['mode']
                # Legal pairs: NEAREST keeps the integer dtype; BILINEAR widens
                # (INT8->INT32, INT16->INT48); FLOAT must stay FLOAT.
                if (
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            elif op['op'] == Op.RESCALE:
                # NOTE(review): the INT16/INT32 test below is `if`, not `elif`;
                # harmless today because it cannot overlap the INT8 branch, but
                # it breaks the chain's symmetry - confirm intent before extending.
                if input_dtype == DType.INT8:
                    if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                if input_dtype in [DType.INT16, DType.INT32]:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.INT48:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.UINT8:
                    if output_dtype != DType.INT8:
                        error_result = True
            elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator dtypes: INT8->INT32, INT16->INT48, FLOAT->FLOAT.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            elif op['op'] == Op.ARGMAX:
                if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
                    error_result = True
            else:
                # Default rule: output dtype must match input dtype.
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1483
1484 @staticmethod
1485 def evWrongRank(check=False, **kwargs):
1486 all_ranks = (1, 2, 3, 4, 5)
1487
1488 # Make a list of incorrect ranks
1489 assert 'op' in kwargs
1490 op = kwargs['op']
1491 rmin, rmax = op['rank']
1492 rank_range = range(rmin, rmax + 1)
1493 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001494 # Remove small incorrect ranks to avoid index errors
1495 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
Matthew Haddon848efb42021-09-09 12:30:53 +01001496 # Set minimum incorrect rank to 3 to avoid index error
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001497 if op['op'] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001498 incorrect_ranks = [3, 5]
1499
1500 error_name = ErrorIf.WrongRank
1501 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1502 error_result = False
1503 error_reason = "Rank not supported for this operator"
1504
1505 if check:
1506 input_shape = kwargs['input_shape']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001507
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001508 if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
Matthew Haddon848efb42021-09-09 12:30:53 +01001509 error_result = True
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001510 elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
1511 error_result = True
1512 elif op['op'] == Op.MATMUL and len(input_shape) != 3:
1513 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001514 else:
1515 if len(input_shape) not in rank_range:
1516 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001517
1518 info_dict = {
1519 "error_name": error_name,
1520 "error_result": error_result,
1521 "error_reason": error_reason,
1522 "param_reqs": param_reqs
1523 }
1524 return info_dict
1525
1526 @staticmethod
1527 def evWrongInputList(check=False, **kwargs):
1528 error_name = ErrorIf.WrongInputList
1529 param_reqs = {"rank": None, "dtype": None, "shape": None}
1530 error_result = False
1531 error_reason = "Op input list does not match expected input"
1532
1533 if check:
1534 op = kwargs['op']
1535 input_list = kwargs['input_list']
1536 num_operands = kwargs['num_operands']
Matthew Haddone807aae2021-10-11 18:12:58 +01001537 # both PAD, TRANSPOSE add an extra const layer in the build function
1538 if op['op'] in [Op.PAD, Op.TRANSPOSE]:
1539 if len(input_list) != num_operands + 1:
1540 error_result = True
1541 else:
1542 if len(input_list) != num_operands:
1543 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001544
1545 info_dict = {
1546 "error_name": error_name,
1547 "error_result": error_result,
1548 "error_reason": error_reason,
1549 "param_reqs": param_reqs
1550 }
1551 return info_dict
1552
1553 @staticmethod
1554 def evWrongOutputList(check=False, **kwargs):
1555 error_name = ErrorIf.WrongOutputList
1556 param_reqs = {"rank": None, "dtype": None, "shape": None}
1557 error_result = False
1558 error_reason = "Op output list does not match expected output"
1559
1560 if check:
1561 output_list = kwargs['output_list']
1562 # Note this will be incorrect if an operator returns more than one output
1563 if len(output_list) != 1:
1564 error_result = True
1565
1566 info_dict = {
1567 "error_name": error_name,
1568 "error_result": error_result,
1569 "error_reason": error_reason,
1570 "param_reqs": param_reqs
1571 }
1572 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001573
1574 @staticmethod
1575 def evMaxDimExceeded(check=False, **kwargs):
1576 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001577 param_reqs = {
1578 "rank": [4,4],
1579 "dtype": [DType.INT8],
1580 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1581 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001582 error_result = False
1583 error_reason = "At least one maximum dimension is larger than 16384"
1584
1585 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001586 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001587 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1588 if ((input_shape[1] > 16384) or
1589 (input_shape[2] > 16384) or
1590 (output_shape[0] > 16384) or
1591 (output_shape[1] > 16384)):
1592 error_result = True
1593
1594 info_dict = {
1595 "error_name": error_name,
1596 "error_result": error_result,
1597 "error_reason": error_reason,
1598 "param_reqs": param_reqs
1599 }
1600 return info_dict
1601
1602 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001603 def evBatchMismatch(check=False, **kwargs):
1604 error_name = ErrorIf.BatchMismatch
1605 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1606 error_result = False
1607 error_reason = "Input batch size not equal to output batch size"
1608
1609 assert 'op' in kwargs
1610 op = kwargs['op']
1611 rmin, rmax = op['rank']
1612 rank_range = range(rmin, rmax + 1)
1613
1614 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001615 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001616 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1617
1618 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1619 error_result = True
1620
1621 info_dict = {
1622 "error_name": error_name,
1623 "error_result": error_result,
1624 "error_reason": error_reason,
1625 "param_reqs": param_reqs
1626 }
1627 return info_dict
1628
1629 @staticmethod
1630 def evChannelMismatch(check=False, **kwargs):
1631 error_name = ErrorIf.ChannelMismatch
1632 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1633 error_result = False
1634 error_reason = "Input channel size not equal to output channel size"
1635
1636 assert 'op' in kwargs
1637 op = kwargs['op']
1638 rmin, rmax = op['rank']
1639 rank_range = range(rmin, rmax + 1)
1640
1641 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001642 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001643 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1644 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1645 error_result = True
1646
1647 info_dict = {
1648 "error_name": error_name,
1649 "error_result": error_result,
1650 "error_reason": error_reason,
1651 "param_reqs": param_reqs
1652 }
1653 return info_dict
1654
1655 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001656 def evStrideSmallerEqualZero(check=False, **kwargs):
1657 error_name = ErrorIf.StrideSmallerEqualZero
1658 param_reqs = {"rank": None, "dtype": None, "shape": None}
1659 error_result = False
1660 error_reason = "Stride value smaller than or equal zero"
1661
1662 if check:
1663 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001664 output_dtype = kwargs['output_dtype']
1665 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1666 stride = kwargs['stride'] # Work around wrong input/output type tests
1667 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001668 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001669 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1670 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001671 else:
1672 stride = kwargs['stride']
1673
1674 if min(stride) <= 0:
1675 error_result = True
1676
1677 info_dict = {
1678 "error_name": error_name,
1679 "error_result": error_result,
1680 "error_reason": error_reason,
1681 "param_reqs": param_reqs
1682 }
1683 return info_dict
1684
1685 @staticmethod
1686 def evStrideLargerEqualMax(check=False, **kwargs):
1687 error_name = ErrorIf.StrideLargerEqualMax
1688 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1689 error_result = False
1690 error_reason = "Stride value larger than or equal to maximum value"
1691
1692 if check:
1693 shift = kwargs['shift']
1694 input_dtype = kwargs['input_dtype']
1695 stride = kwargs['stride']
1696 if input_dtype in [DType.INT8, DType.INT16]:
1697 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1698 error_result = True
1699 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1700 error_result = True
1701
1702 info_dict = {
1703 "error_name": error_name,
1704 "error_result": error_result,
1705 "error_reason": error_reason,
1706 "param_reqs": param_reqs
1707 }
1708 return info_dict
1709
1710
1711 @staticmethod
1712 def evStrideLargerDimension(check=False, **kwargs):
1713 error_name = ErrorIf.StrideLargerDimension
1714 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1715 error_result = False
1716 error_reason = "Stride value larger than or equal to H/W dimension"
1717
1718 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001719 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001720 input_dtype = kwargs['input_dtype']
1721 stride = kwargs['stride_fp']
1722
1723 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1724 error_result = True
1725
1726 info_dict = {
1727 "error_name": error_name,
1728 "error_result": error_result,
1729 "error_reason": error_reason,
1730 "param_reqs": param_reqs
1731 }
1732 return info_dict
1733
1734
1735 @staticmethod
1736 def evOffsetSmallerEqualMin(check=False, **kwargs):
1737 error_name = ErrorIf.OffsetSmallerEqualMin
1738 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1739 error_result = False
1740 error_reason = "Offset value smaller than or equal to minimum value"
1741
1742 if check:
1743 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001744 output_dtype = kwargs['output_dtype']
1745 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001746 offset = kwargs['offset_fp']
1747 else:
1748 offset = kwargs['offset']
1749
1750 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1751 error_result = True
1752 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1753 error_result = True
1754
1755 info_dict = {
1756 "error_name": error_name,
1757 "error_result": error_result,
1758 "error_reason": error_reason,
1759 "param_reqs": param_reqs
1760 }
1761 return info_dict
1762
1763 @staticmethod
1764 def evOffsetLargerEqualMax(check=False, **kwargs):
1765 error_name = ErrorIf.OffsetLargerEqualMax
1766 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1767 error_result = False
1768 error_reason = "Offset value larger than or equal to maximum value"
1769
1770 if check:
1771 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001772 output_dtype = kwargs['output_dtype']
1773 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001774 offset = kwargs['offset_fp']
1775 else:
1776 offset = kwargs['offset']
1777
1778 if shift >= 0:
1779 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1780 error_result = True
1781
1782 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1783 error_result = True
1784 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1785 error_result = True
1786
1787 info_dict = {
1788 "error_name": error_name,
1789 "error_result": error_result,
1790 "error_reason": error_reason,
1791 "param_reqs": param_reqs
1792 }
1793 return info_dict
1794
1795 @staticmethod
1796 def evShiftNotZero(check=False, **kwargs):
1797 error_name = ErrorIf.ShiftNotZero
1798 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1799 error_result = False
1800 error_reason = "Shift value must be zero for float input"
1801
1802 if check:
1803 shift = kwargs['shift']
1804 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001805 output_dtype = kwargs['output_dtype']
1806 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001807 error_result = True
1808
1809 info_dict = {
1810 "error_name": error_name,
1811 "error_result": error_result,
1812 "error_reason": error_reason,
1813 "param_reqs": param_reqs
1814 }
1815 return info_dict
1816
1817
1818 @staticmethod
1819 def evShiftSmallerOne(check=False, **kwargs):
1820 error_name = ErrorIf.ShiftSmallerOne
1821 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1822 error_result = False
1823 error_reason = "Shift value smaller than one"
1824
1825 if check:
1826 shift = kwargs['shift']
1827 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001828 output_dtype = kwargs['output_dtype']
1829 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001830 error_result = True
1831
1832 info_dict = {
1833 "error_name": error_name,
1834 "error_result": error_result,
1835 "error_reason": error_reason,
1836 "param_reqs": param_reqs
1837 }
1838 return info_dict
1839
1840 @staticmethod
1841 def evShiftLargerEleven(check=False, **kwargs):
1842 error_name = ErrorIf.ShiftLargerEleven
1843 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1844 error_result = False
1845 error_reason = "Shift value larger than eleven"
1846
1847 if check:
1848 shift = kwargs['shift']
1849 if shift > 11:
1850 error_result = True
1851
1852 info_dict = {
1853 "error_name": error_name,
1854 "error_result": error_result,
1855 "error_reason": error_reason,
1856 "param_reqs": param_reqs
1857 }
1858 return info_dict
1859
1860
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001861 @staticmethod
1862 def evRankMismatch(check=False, **kwargs):
1863 error_name = ErrorIf.RankMismatch
1864 param_reqs = {"rank": None, "dtype": None, "shape": None}
1865 error_result = False
1866 error_reason = "Input Rank does not match output rank"
1867
1868 if check:
1869 input1_shape = kwargs['input1'].shape
1870 input2_shape = kwargs['input2'].shape
1871 output_shape = kwargs['result_tensor'].shape
1872 if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
1873 error_result = True
1874
1875 info_dict = {
1876 "error_name": error_name,
1877 "error_result": error_result,
1878 "error_reason": error_reason,
1879 "param_reqs": param_reqs
1880 }
1881 return info_dict
1882
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001883 @staticmethod
1884 def evInputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001885 op = kwargs['op']
1886 inputDtypes = op['types'].copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001887 # If inputDtypes is a list then only the first two elements are INT8 inputs
1888 if isinstance(inputDtypes, list):
1889 inputDtypes = inputDtypes[2:]
1890
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001891 if DType.INT8 in inputDtypes:
1892 inputDtypes.remove(DType.INT8)
1893 if DType.UINT8 in inputDtypes:
1894 inputDtypes.remove(DType.UINT8)
1895
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001896 error_name = ErrorIf.InputZeroPointNotZero
1897 param_reqs = {
1898 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001899 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001900 "shape": None
1901 }
1902 error_result = False
1903 error_reason = "Input DType not INT8 and zero point not 0"
1904
1905 if check:
1906 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01001907 if isinstance(kwargs['qinfo'], tuple):
1908 qinfo = kwargs['qinfo']
1909 input_zero_point = qinfo[0]
1910 else:
1911 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1912 qinfo = kwargs['qinfo'].ints
1913 input_zero_point = qinfo[0][1]
1914
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001915 if op['op'] == Op.MATMUL:
1916 input1_dtype = kwargs['input_dtype']
1917 input2_dtype = kwargs['input2_dtype']
1918 qinfo = kwargs['qinfo'].ints
1919 input1_zero_point = qinfo[0][1]
1920 input2_zero_point = qinfo[1][1]
1921 if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
1922 error_result = True
1923 else:
1924 if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
1925 error_result = True
1926
1927 info_dict = {
1928 "error_name": error_name,
1929 "error_result": error_result,
1930 "error_reason": error_reason,
1931 "param_reqs": param_reqs
1932 }
1933 return info_dict
1934
1935
1936 @staticmethod
1937 def evWeightZeroPointNotZero(check=False, **kwargs):
1938 op = kwargs['op']
1939
1940 # exclude inputs with INT8 weights
1941 inputDtypes = [t for t in op['types']
1942 if not isinstance(t, list) or t[1] != DType.INT8]
1943
1944 error_name = ErrorIf.WeightZeroPointNotZero
1945 param_reqs = {
1946 "rank": None,
1947 "dtype": inputDtypes,
1948 "shape": None
1949 }
1950 error_result = False
1951 error_reason = "Weight DType not INT8 and zero point not 0"
1952
1953 if check:
1954 weight_dtype = kwargs['weight_dtype']
1955 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
1956 qinfo = kwargs['qinfo'].ints
1957 weight_zero_point = qinfo[1][1]
1958 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001959 error_result = True
1960
1961 info_dict = {
1962 "error_name": error_name,
1963 "error_result": error_result,
1964 "error_reason": error_reason,
1965 "param_reqs": param_reqs
1966 }
1967 return info_dict
1968
1969
1970 @staticmethod
1971 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001972 op = kwargs['op']
1973 inputDtypes = op['types'].copy()
1974 if DType.INT8 in inputDtypes:
1975 inputDtypes.remove(DType.INT8)
1976 if DType.UINT8 in inputDtypes:
1977 inputDtypes.remove(DType.UINT8)
1978
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001979 error_name = ErrorIf.OutputZeroPointNotZero
1980 param_reqs = {
1981 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001982 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001983 "shape": None
1984 }
1985 error_result = False
1986 error_reason = "Output DType not INT8 and zero point not 0"
1987
1988 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001989 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01001990 output_dtype = kwargs['output_dtype']
1991 if isinstance(kwargs['qinfo'], tuple):
1992 qinfo = kwargs['qinfo']
1993 output_zero_point = qinfo[1]
1994 else:
1995 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1996 qinfo = kwargs['qinfo'].ints
1997 output_zero_point = qinfo[1][1]
1998 if op['op'] == Op.AVG_POOL2D:
1999 if input_dtype != DType.INT8 and output_zero_point != 0:
2000 error_result = True
2001 elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002002 error_result = True
2003
2004 info_dict = {
2005 "error_name": error_name,
2006 "error_result": error_result,
2007 "error_reason": error_reason,
2008 "param_reqs": param_reqs
2009 }
2010 return info_dict
2011
Matthew Haddond6ce7252021-09-29 15:35:44 +01002012 @staticmethod
2013 def evAxisSmallerZero(check=False, **kwargs):
2014 error_name = ErrorIf.AxisSmallerZero
2015 param_reqs = {"rank": None, "dtype": None, "shape": None}
2016 error_result = False
2017 error_reason = "Axis smaller than zero"
2018
2019 if check:
2020 axis = kwargs['axis']
2021 if axis < 0:
2022 error_result = True
2023
2024 info_dict = {
2025 "error_name": error_name,
2026 "error_result": error_result,
2027 "error_reason": error_reason,
2028 "param_reqs": param_reqs
2029 }
2030 return info_dict
2031
2032
2033 @staticmethod
2034 def evAxisLargerRank(check=False, **kwargs):
2035 error_name = ErrorIf.AxisLargerRank
2036 param_reqs = {"rank": None, "dtype": None, "shape": None}
2037 error_result = False
2038 error_reason = "Axis larger than rank"
2039
2040 if check:
2041 axis = kwargs['axis']
2042 shape = kwargs['input_shape']
2043 if axis > len(shape):
2044 error_result = True
2045
2046 info_dict = {
2047 "error_name": error_name,
2048 "error_result": error_result,
2049 "error_reason": error_reason,
2050 "param_reqs": param_reqs
2051 }
2052 return info_dict
2053
2054
2055 @staticmethod
2056 def evShapeOfAxisNotOne(check=False, **kwargs):
2057 error_name = ErrorIf.ShapeOfAxisNotOne
2058 param_reqs = {"rank": None, "dtype": None, "shape": None}
2059 error_result = False
2060 error_reason = "shape[axis] is not equal to 1"
2061
2062 if check:
2063 axis = kwargs['axis']
2064 shape = kwargs['output_shape']
2065 if (0 <= axis < len(shape)) and shape[axis] != 1:
2066 error_result = True
2067
2068 info_dict = {
2069 "error_name": error_name,
2070 "error_result": error_result,
2071 "error_reason": error_reason,
2072 "param_reqs": param_reqs
2073 }
2074 return info_dict
2075
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002076
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002077 @staticmethod
2078 def evPadSmallerZero(check=False, **kwargs):
2079 error_name = ErrorIf.PadSmallerZero
2080 param_reqs = {"rank": None, "dtype": None, "shape": None}
2081 error_result = False
2082 error_reason = "At least one pad is smaller than zero"
2083
2084 if check:
Matthew Haddone807aae2021-10-11 18:12:58 +01002085 op = kwargs['op']
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002086 pad = kwargs['pad']
Matthew Haddone807aae2021-10-11 18:12:58 +01002087 if op['op'] == Op.PAD:
2088 for padding in pad:
2089 if min(padding) < 0:
2090 error_result = True
2091 else:
2092 if min(pad) < 0:
2093 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002094
2095 info_dict = {
2096 "error_name": error_name,
2097 "error_result": error_result,
2098 "error_reason": error_reason,
2099 "param_reqs": param_reqs
2100 }
2101 return info_dict
2102
2103
2104 @staticmethod
2105 def evPadLargerEqualKernel(check=False, **kwargs):
2106 error_name = ErrorIf.PadLargerEqualKernel
2107 param_reqs = {"rank": None, "dtype": None, "shape": None}
2108 error_result = False
2109 error_reason = "At least one pad is larger than kernel dimension"
2110
2111 if check:
2112 pad = kwargs['pad']
2113 kernel = kwargs['kernel']
2114 if min(pad) > 0 and min(kernel) > 1:
2115 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
2116 error_result = True
2117
2118 info_dict = {
2119 "error_name": error_name,
2120 "error_result": error_result,
2121 "error_reason": error_reason,
2122 "param_reqs": param_reqs
2123 }
2124 return info_dict
2125
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """Validator for ERROR_IF: declared pooling output H/W differs from the
        H/W computed from input size, pad, kernel and stride."""
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between output shape provided and expected output shape"

        if check:
            # pad = (top, bottom, left, right)
            pad = kwargs['pad']
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            # kernel/stride = (y, x)
            kernel = kwargs['kernel']
            kernel_y, kernel_x = kernel[0], kernel[1]

            # Shapes are NHWC.
            input_shape = kwargs['input_shape']
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs['output_shape']
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs['stride']
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            # (guarded against zero stride; params_valid below implies
            # stride >= 1, so y_correct/x_correct are always defined when read)
            if stride_x != 0 and stride_y != 0:
                y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
                x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x

            # ensure parameters are valid
            # (so this validator does not also fire on tests aimed at the
            # KernelSmallerOne / StrideSmallerOne / Pad* error checks)
            params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
                            and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
2168
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002169 @staticmethod
2170 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
2171 error_name = ErrorIf.ArgmaxOutputShapeMismatch
2172 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2173 error_result = False
2174 error_reason = "Mismatch between output shape provided and expected output shape"
2175
2176 if check:
2177 output_shape = kwargs['output_shape']
2178 input_shape = kwargs['input_shape']
2179 axis = kwargs['axis']
2180
2181 dimension_match = True
2182 axis_shift = 0
2183
2184 # Check that rank is correct before trying to check dimensions
2185 if (len(input_shape) - 1) == len(output_shape):
2186 for i in range(len(input_shape)):
2187 if i == axis:
2188 axis_shift = 1
2189 continue
2190 if input_shape[i] != output_shape[i - axis_shift]:
2191 dimension_match = False
2192
2193 if not dimension_match:
2194 error_result = True
2195
2196 info_dict = {
2197 "error_name": error_name,
2198 "error_result": error_result,
2199 "error_reason": error_reason,
2200 "param_reqs": param_reqs
2201 }
2202 return info_dict
2203
2204 @staticmethod
2205 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2206 error_name = ErrorIf.ArgmaxOutputRankMismatch
2207 param_reqs = {"rank": None, "dtype": None, "shape": None}
2208 error_result = False
2209 error_reason = "Mismatch between output shape provided and expected output shape"
2210
2211 if check:
2212 output_shape = kwargs['output_shape']
2213 input_shape = kwargs['input_shape']
2214 axis = kwargs['axis']
2215 valid_params = axis >= 0 and axis < len(input_shape)
2216
2217 if valid_params and (len(input_shape) - 1) != len(output_shape):
2218 error_result = True
2219
2220 info_dict = {
2221 "error_name": error_name,
2222 "error_result": error_result,
2223 "error_reason": error_reason,
2224 "param_reqs": param_reqs
2225 }
2226 return info_dict
2227
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002228
2229 @staticmethod
2230 def evKernelSmallerOne(check=False, **kwargs):
2231 error_name = ErrorIf.KernelSmallerOne
2232 param_reqs = {"rank": None, "dtype": None, "shape": None}
2233 error_result = False
2234 error_reason = "At least one kernel dimension is smaller than zero"
2235
2236 if check:
2237 kernel = kwargs['kernel']
2238 if min(kernel) < 1:
2239 error_result = True
2240
2241 info_dict = {
2242 "error_name": error_name,
2243 "error_result": error_result,
2244 "error_reason": error_reason,
2245 "param_reqs": param_reqs
2246 }
2247 return info_dict
2248
2249 @staticmethod
2250 def evStrideSmallerOne(check=False, **kwargs):
2251 error_name = ErrorIf.StrideSmallerOne
2252 param_reqs = {"rank": None, "dtype": None, "shape": None}
2253 error_result = False
2254 error_reason = "At least one stride dimension is smaller than zero"
2255
2256 if check:
2257 stride = kwargs['stride']
2258 if min(stride) < 1:
2259 error_result = True
2260
2261 info_dict = {
2262 "error_name": error_name,
2263 "error_result": error_result,
2264 "error_reason": error_reason,
2265 "param_reqs": param_reqs
2266 }
2267 return info_dict
2268
Matthew Haddonc2025212021-10-08 21:21:05 +01002269 @staticmethod
2270 def evScaleTrue(check=False, **kwargs):
2271 error_name = ErrorIf.ScaleTrue
2272 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2273 error_result = False
2274 error_reason = "Scale set to true but input type is INT48"
2275
2276 if check:
2277 input_dtype = kwargs['input_dtype']
2278 scale32 = kwargs['scale32']
2279 if scale32 and input_dtype == DType.INT48:
2280 error_result = True
2281
2282 info_dict = {
2283 "error_name": error_name,
2284 "error_result": error_result,
2285 "error_reason": error_reason,
2286 "param_reqs": param_reqs
2287 }
2288 return info_dict
2289
2290 @staticmethod
2291 def evScaleNotTrue(check=False, **kwargs):
2292 error_name = ErrorIf.ScaleNotTrue
2293 param_reqs = {"rank": None, "dtype": None, "shape": None}
2294 error_result = False
2295 error_reason = "Scale set to false but double round set to true"
2296
2297 if check:
2298 scale32 = kwargs['scale32']
2299 double_round = kwargs['double_round']
2300 if not scale32 and double_round:
2301 error_result = True
2302
2303 info_dict = {
2304 "error_name": error_name,
2305 "error_result": error_result,
2306 "error_reason": error_reason,
2307 "param_reqs": param_reqs
2308 }
2309 return info_dict
2310
Matthew Haddone807aae2021-10-11 18:12:58 +01002311 @staticmethod
2312 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2313 error_name = ErrorIf.TensorSizeInputOutputMismatch
2314 param_reqs = {"rank": None, "dtype": None, "shape": None}
2315 error_result = False
2316 error_reason = "Input tensor size does not match output tensor size"
2317
2318 if check:
2319 input_shape = kwargs['input_shape']
2320 output_shape = kwargs['output_shape']
2321 input_size = np.prod(input_shape)
2322 output_size = np.prod(output_shape)
2323 if input_size != output_size:
2324 error_result = True
2325
2326 info_dict = {
2327 "error_name": error_name,
2328 "error_result": error_result,
2329 "error_reason": error_reason,
2330 "param_reqs": param_reqs
2331 }
2332 return info_dict
2333
2334 @staticmethod
2335 def evStartSmallerZero(check=False, **kwargs):
2336 error_name = ErrorIf.StartSmallerZero
2337 param_reqs = {"rank": None, "dtype": None, "shape": None}
2338 error_result = False
2339 error_reason = "Starting point smaller than zero"
2340
2341 if check:
2342 input_shape = kwargs['input_shape']
2343 start = kwargs['start']
2344 rank = len(input_shape)
2345 if len(start) == rank:
2346 for index in range(rank):
2347 if start[index] < 0:
2348 error_result = True
2349
2350 info_dict = {
2351 "error_name": error_name,
2352 "error_result": error_result,
2353 "error_reason": error_reason,
2354 "param_reqs": param_reqs
2355 }
2356 return info_dict
2357
2358
2359 @staticmethod
2360 def evSizeSmallerEqualZero(check=False, **kwargs):
2361 error_name = ErrorIf.SizeSmallerEqualZero
2362 param_reqs = {"rank": None, "dtype": None, "shape": None}
2363 error_result = False
2364 error_reason = "Size smaller than or equal to zero"
2365
2366 if check:
2367 input_shape = kwargs['input_shape']
2368 size = kwargs['size']
2369 rank = len(input_shape)
2370 if len(size) == rank:
2371 for index in range(rank):
2372 if size[index] <= 0:
2373 error_result = True
2374
2375 info_dict = {
2376 "error_name": error_name,
2377 "error_result": error_result,
2378 "error_reason": error_reason,
2379 "param_reqs": param_reqs
2380 }
2381 return info_dict
2382
2383
2384 @staticmethod
2385 def evStartSizeOutsideBounds(check=False, **kwargs):
2386 error_name = ErrorIf.StartSizeOutsideBounds
2387 param_reqs = {"rank": None, "dtype": None, "shape": None}
2388 error_result = False
2389 error_reason = "starting point plus size larger than input dimension"
2390
2391 if check:
2392 input_shape = kwargs['input_shape']
2393 start = kwargs['start']
2394 size = kwargs['size']
2395 rank = len(input_shape)
2396 if len(start) == rank and len(size) == rank:
2397 for index in range(rank):
2398 if start[index] + size[index] > input_shape[index]:
2399 error_result = True
2400
2401 info_dict = {
2402 "error_name": error_name,
2403 "error_result": error_result,
2404 "error_reason": error_reason,
2405 "param_reqs": param_reqs
2406 }
2407 return info_dict
2408
2409
2410 @staticmethod
2411 def evSizeOutputShapeMismatch(check=False, **kwargs):
2412 error_name = ErrorIf.SizeOutputShapeMismatch
2413 param_reqs = {"rank": None, "dtype": None, "shape": None}
2414 error_result = False
2415 error_reason = "Size does not match output dimension"
2416
2417 if check:
2418 input_shape = kwargs['input_shape']
2419 output_shape = kwargs['output_shape']
2420 size = kwargs['size']
2421 rank = len(input_shape)
2422 if len(size) == rank:
2423 for index in range(rank):
2424 if size[index] != output_shape[index]:
2425 error_result = True
2426
2427 info_dict = {
2428 "error_name": error_name,
2429 "error_result": error_result,
2430 "error_reason": error_reason,
2431 "param_reqs": param_reqs
2432 }
2433 return info_dict
2434
2435 @staticmethod
2436 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2437 error_name = ErrorIf.InputSizeStartLengthMismatch
2438 param_reqs = {"rank": None, "dtype": None, "shape": None}
2439 error_result = False
2440 error_reason = "rank of input not equal to length of start or size"
2441
2442 if check:
2443 input_shape = kwargs['input_shape']
2444 start = kwargs['start']
2445 size = kwargs['size']
2446 rank = len(input_shape)
2447 if rank != len(start) or rank != len(size):
2448 error_result = True
2449
2450 info_dict = {
2451 "error_name": error_name,
2452 "error_result": error_result,
2453 "error_reason": error_reason,
2454 "param_reqs": param_reqs
2455 }
2456 return info_dict
2457
2458 @staticmethod
2459 def evIndexOutsideBounds(check=False, **kwargs):
2460 error_name = ErrorIf.IndexOutsideBounds
2461 param_reqs = {"rank": None, "dtype": None, "shape": None}
2462 error_result = False
2463 error_reason = "Index outside of allowed bounds"
2464
2465 if check:
2466 input_shape = kwargs['input_shape']
2467 perms = kwargs['perms']
2468 rank = len(input_shape)
2469
2470 for index in perms:
2471 if index < 0 or index > rank:
2472 error_result = True
2473
2474 info_dict = {
2475 "error_name": error_name,
2476 "error_result": error_result,
2477 "error_reason": error_reason,
2478 "param_reqs": param_reqs
2479 }
2480 return info_dict
2481
2482 @staticmethod
2483 def evIndexUsedTwice(check=False, **kwargs):
2484 error_name = ErrorIf.IndexUsedTwice
2485 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2486 error_result = False
2487 error_reason = "Index used multiple times"
2488
2489 if check:
2490 input_shape = kwargs['input_shape']
2491 perms = kwargs['perms']
2492 rank = len(input_shape)
2493
2494 unique_indices = []
2495 for index in perms:
2496 if index in unique_indices:
2497 error_result = True
2498 else:
2499 unique_indices.append(index)
2500
2501 info_dict = {
2502 "error_name": error_name,
2503 "error_result": error_result,
2504 "error_reason": error_reason,
2505 "param_reqs": param_reqs
2506 }
2507 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002508
2509
Matthew Haddonb724efc2021-08-25 16:40:29 +01002510class TosaInvalidValidator:
2511
2512 @staticmethod
2513 def ivWrongDataTypeOrModeResize(**kwargs):
2514 input_dtype = kwargs["input_dtype"]
2515 args = kwargs["args"]
2516 mode = args[0]
2517 stride = args[1]
2518 stride_fp = args[4]
2519 output_dtype = args[8]
2520
2521 if mode == ResizeMode.BILINEAR:
2522 # Invalid output data type / Invalid input datatype
2523 return (
2524 not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
2525 not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
2526 not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or
2527 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
2528 )
2529 elif mode == ResizeMode.NEAREST:
2530 # Invalid output data type / Invalid input datatype
2531 return (
2532 (input_dtype != output_dtype) or
2533 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
2534 )
2535 else:
2536 # Invalid resize mode
2537 return True
2538
2539 @staticmethod
2540 def ivBadStride(**kwargs):
2541 input_dtype = kwargs["input_dtype"]
2542 args = kwargs["args"]
2543 stride_x = args[1][0]
2544 stride_y = args[1][1]
2545 stride_fp_x = args[4][0]
2546 stride_fp_y = args[4][1]
2547
2548 if input_dtype == DType.FLOAT:
2549 if stride_fp_x <= 0 or stride_fp_y <= 0:
2550 # Negative or zero stride
2551 return True
2552 else:
2553 if stride_x <= 0 or stride_y <= 0:
2554 # Negative or zero stride
2555 return True
2556 return False
2557
2558
Matthew Haddonb724efc2021-08-25 16:40:29 +01002559 @staticmethod
2560 def ivHeightWidthSmallerZero(**kwargs):
2561 opName = kwargs['opName']
2562
2563 inputShapes = kwargs['shapeList']
2564 input = inputShapes[0]
2565 if not opName.endswith("pool2d"):
2566 filter = inputShapes[1]
2567
2568 args = kwargs['args']
2569 strides = args[0]
2570 padding = args[1]
2571 dilations = args[2]
2572 if opName.endswith("pool2d"):
2573 kernel = args[2]
2574
2575 if opName.startswith('conv2d'):
2576 h = (
2577 input[1]
2578 - filter[1]
2579 - (filter[1] - 1) * (dilations[0] - 1)
2580 + padding[0]
2581 + padding[1]
2582 ) // strides[0] + 1
2583
2584 w = (
2585 input[2]
2586 - filter[2]
2587 - (filter[2] - 1) * (dilations[1] - 1)
2588 + padding[2]
2589 + padding[3]
2590 ) // strides[1] + 1
2591 elif opName.startswith("depthwise_conv2d"):
2592 h = (
2593 input[1]
2594 - filter[0]
2595 - (filter[0] - 1) * (dilations[0] - 1)
2596 + padding[0]
2597 + padding[1]
2598 ) // strides[0] + 1
2599
2600 w = (
2601 input[2]
2602 - filter[1]
2603 - (filter[1] - 1) * (dilations[1] - 1)
2604 + padding[2]
2605 + padding[3]
2606 ) // strides[1] + 1
2607 elif opName.endswith("pool2d"):
2608 h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
2609 w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
2610 else:
2611 assert False, "Unrecognized Op"
2612
2613 if h <= 0 or w <= 0:
2614 # Invalid parameter combination
2615 return True
2616 return False
2617
2618 @staticmethod
2619 def ivNonPositiveOutputShape(**kwargs):
2620 args = kwargs['args']
2621 output_shape = args[3]
2622 if output_shape[1] <= 0 or output_shape[2] <= 0:
2623 # Negative output shape
2624 return True
2625 return False
2626
2627
Kevin Cheng550ccc52021-03-03 11:21:43 -08002628
Eric Kunzee5e26762020-10-13 16:11:07 -07002629class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002630 # Maximum rank of tensor supported by test generator.
2631 TOSA_TENSOR_MAX_RANK = 6
2632
Eric Kunzee5e26762020-10-13 16:11:07 -07002633 def __init__(self, args):
2634 self.args = args
2635 self.basePath = args.output_dir
2636 self.random_seed = args.random_seed
2637 self.ser = None
2638 self.rng = np.random.default_rng(self.random_seed)
2639 self.createDynamicOpLists()
2640 self.initOpListDefaults()
2641 self.quantGen = TosaQuantGen()
2642 # Force makeShape to do a specific starting shape
2643 self.targetted_shape = None
2644
    def createSerializer(self, opName, testPath):
        """Create a fresh TosaSerializer writing under
        basePath/opName/testPath, creating the directory if needed."""
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)
2651
2652 def getSerializer(self):
2653 return self.ser
2654
    def serialize(self, testName):
        """Write the current test to <basePath>/<testPath>/<testName>.tosa
        and a matching desc.json description file."""
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002663
Matthew Haddon74567092021-07-16 15:38:20 +01002664 def resetRNG(self, seed=None):
2665 if seed == None:
2666 seed = self.random_seed + 1
2667 self.rng = np.random.default_rng(seed)
2668
Eric Kunzee5e26762020-10-13 16:11:07 -07002669 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07002670 if dtype == DType.BOOL:
2671 np_dt = np.bool
2672 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07002673 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002674 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002675 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002676 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002677 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
2678 elif dtype == DType.UINT8:
2679 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002680 elif dtype == DType.INT16:
2681 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
2682 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 return np.int32(
2684 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
2685 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002686 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002687 return np.int64(
2688 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
2689 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002690 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002691 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002692 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002693 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002694
Kevin Cheng989cb052021-04-28 16:29:44 -07002695 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002696 placeholders = []
2697
Kevin Cheng989cb052021-04-28 16:29:44 -07002698 assert len(shape_list) == len(dtype_list)
2699
2700 for idx, shape in enumerate(shape_list):
2701 arr = self.getRandTensor(shape, dtype_list[idx])
2702 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002703
2704 return placeholders
2705
Kevin Cheng989cb052021-04-28 16:29:44 -07002706 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002707 consts = []
2708
Kevin Cheng989cb052021-04-28 16:29:44 -07002709 assert len(shape_list) == len(dtype_list)
2710
2711 for idx, shape in enumerate(shape_list):
2712 arr = self.getRandTensor(shape, dtype_list[idx])
2713 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002714
2715 return consts
2716
2717 def makeShape(self, rank):
2718 if self.targetted_shape:
2719 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002720 return np.int32(
2721 self.rng.integers(
2722 low=self.args.tensor_shape_range[0],
2723 high=self.args.tensor_shape_range[1],
2724 size=rank,
2725 )
2726 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002727
    def setTargetShape(self, shape):
        """Force makeShape() to return exactly this shape (None to disable)."""
        self.targetted_shape = shape
2730
2731 def randInt(self, low=0, high=256):
2732 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
2733
2734 def getRandNumberDType(self, dtype):
2735 if dtype == DType.FLOAT:
2736 return self.rng.random()
2737 elif dtype == DType.BOOL:
2738 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07002739 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002740 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002741 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002742 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002743 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07002744 elif dtype == DType.INT16:
2745 low, high = (-32768, 32768)
2746 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002747 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07002748 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002749 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07002750 # Special size
2751 return np.int64(self.rng.integers(low, high, size=1))[0]
2752 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002753 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002754
2755 return np.int32(self.rng.integers(low, high, size=1))[0]
2756
2757 def shapeStr(self, shape):
2758
2759 sStr = []
2760 # Convert to strings
2761 for i in shape:
2762 sStr.append(str(i))
2763
Kevin Cheng550ccc52021-03-03 11:21:43 -08002764 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002765
2766 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07002767 if isinstance(t, list):
2768 assert len(t) >= 2
2769 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002770 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002771 if t == DType.BOOL:
2772 return "b"
2773 elif t == DType.INT4:
2774 return "i4"
2775 elif t == DType.INT8:
2776 return "i8"
2777 elif t == DType.UINT8:
2778 return "u8"
2779 elif t == DType.INT16:
2780 return "i16"
2781 elif t == DType.INT32:
2782 return "i32"
2783 elif t == DType.INT48:
2784 return "i48"
2785 elif t == DType.FLOAT:
2786 return "float"
2787 else:
2788 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002789
2790 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002791 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002792 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002793 return 4
2794 elif t == DType.INT8:
2795 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002796 elif t == DType.UINT8:
2797 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002798 elif t == DType.INT16:
2799 return 16
2800 elif t == DType.INT32:
2801 return 32
2802 elif t == DType.INT48:
2803 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01002804 elif t == DType.FLOAT:
2805 return 32
2806 elif t == DType.BOOL:
2807 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002808 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002809 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002810
2811 # Argument generators
2812 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2813 # Where the string descriptor is used to generate the test name and
2814 # The build_fcn_arg_list is expanded and passed to the operator test
2815 # build function
2816
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002817 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
2818 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
2819
Matthew Haddon848efb42021-09-09 12:30:53 +01002820 # build_placeholder returns an int, ABS/other ops does not
2821 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002822 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
2823 return result_tens
2824 elif op['op'] == Op.IDENTITY:
2825 self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
2826 return result_tens
2827
2828 # Ensure new output type has correct qinfo
2829 if error_name == ErrorIf.WrongOutputType:
2830 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
2831 qinfo = ts.TosaSerializerQuantInfo()
2832 qinfo.UnaryQuantInfo(
2833 TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2834 )
2835
2836 # Invalidate Input/Output list for error if checks.
2837 input_list = [a.name]
2838 output_list = [result_tens.name]
2839 pCount, cCount = op["operands"]
2840 num_operands = pCount + cCount
2841 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2842
2843 TosaErrorValidator.evValidateErrorIfs(
2844 self.ser,
2845 validator_fcns,
2846 error_name,
2847 op=op,
2848 input_dtype=a.dtype,
2849 output_dtype=result_tens.dtype,
2850 qinfo = qinfo,
2851 result_tensor = result_tens,
2852 input_list=input_list,
2853 output_list=output_list,
2854 num_operands=num_operands,
2855 )
2856
2857 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002858 return result_tens
2859
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002860 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
2861 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
2862
2863
2864 # Invalidate Input/Output list for error if checks.
2865 input_list = [a.name, b.name]
2866 output_list = [result_tens.name]
2867 pCount, cCount = op["operands"]
2868 num_operands = pCount + cCount
2869 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2870
2871 TosaErrorValidator.evValidateErrorIfs(
2872 self.ser,
2873 validator_fcns,
2874 error_name,
2875 op=op,
2876 input1 = a,
2877 input2 = b,
2878 input_dtype = a.dtype,
2879 output_dtype = result_tens.dtype,
2880 result_tensor = result_tens,
2881 input_list=input_list,
2882 output_list=output_list,
2883 num_operands=num_operands,
2884 )
2885
2886 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07002887 return result_tens
2888
    def build_binary_nonbroadcast(self, op, a, b):
        """Serialize a non-broadcasting binary operator (inputs must already
        match shapes) and return the output tensor handle."""
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
        return result_tens
2893
Kevin Chengaee1fac2020-11-11 13:54:06 -08002894 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002895 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002896
2897 attr = ts.TosaSerializerAttribute()
2898 attr.ArithmeticRightShiftAttribute(round)
2899
Matthew Haddon848efb42021-09-09 12:30:53 +01002900 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002901 return result_tens
2902
    def build_mul(self, op, a, b, shift):
        """Serialize MUL with the given result shift; integer multiplies
        always produce an INT32 result tensor."""
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
        return result_tens
2916
    def build_table(self, op, a):
        """Serialize TABLE: build a random lookup table constant sized for
        the input dtype (513 entries for INT16, 256 for INT8) and add the
        operator."""
        # Constant size depending on type, random values
        if a.dtype == DType.INT16:
            table_dtype = DType.INT16
            table_arr = self.getRandTensor([513], table_dtype)
        else:
            assert a.dtype == DType.INT8
            table_dtype = DType.INT8
            table_arr = self.getRandTensor([256], table_dtype)

        table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
        result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
        self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)

        return result_tens
2932
    def build_select(self, op, cond, a, b):
        """Serialize SELECT(cond, a, b) and return the output tensor handle."""
        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
        return result_tens
2937
    def build_comparison(self, op, a, b):
        """Serialize a comparison operator (boolean output) on ``a``, ``b``."""
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
        return result_tens
2942
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002943 def build_argmax(self, op, a, axis, validator_fcns, error_name):
2944 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
2945
2946 # Invalidate Input/Output list for error if checks.
2947 input_list = [a.name]
2948 output_list = [result_tens.name]
2949 pCount, cCount = op["operands"]
2950 num_operands = pCount + cCount
2951 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2952
2953 TosaErrorValidator.evValidateErrorIfs(
2954 self.ser,
2955 validator_fcns,
2956 error_name,
2957 op=op,
2958 axis=axis,
2959 input_shape = a.shape,
2960 input_dtype = a.dtype,
2961 output_shape = result_tens.shape,
2962 output_dtype = result_tens.dtype,
2963 result_tensor = result_tens,
2964 input_list=input_list,
2965 output_list=output_list,
2966 num_operands=num_operands,
2967 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002968
2969 attr = ts.TosaSerializerAttribute()
2970 attr.AxisAttribute(axis)
2971
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002972 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002973 return result_tens
2974
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002975 def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
2976 result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
2977
2978 # Ensure new output type has correct qinfo
2979 if error_name == ErrorIf.WrongInputType:
2980 if input.dtype not in [DType.INT8, DType.UINT8]:
2981 qinfo = ts.TosaSerializerQuantInfo()
2982 qinfo.UnaryQuantInfo(
2983 TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2984 )
2985
2986 # Invalidate Input/Output list for error if checks.
2987 input_list = [input.name]
2988 output_list = [result_tens.name]
2989 pCount, cCount = op["operands"]
2990 num_operands = pCount + cCount
2991 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2992
2993 TosaErrorValidator.evValidateErrorIfs(
2994 self.ser,
2995 validator_fcns,
2996 error_name,
2997 op=op,
2998 input_shape=input.shape,
2999 input_dtype=input.dtype,
3000 output_shape=result_tens.shape,
3001 output_dtype=result_tens.dtype,
3002 kernel=kernel,
3003 stride=stride,
3004 pad=pad,
3005 qinfo = qinfo,
3006 result_tensor = result_tens,
3007 input_list=input_list,
3008 output_list=output_list,
3009 num_operands=num_operands,
3010 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003011
3012 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003013 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07003014
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003015 self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003016 return result_tens
3017
    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        """Serialize CONV2D (padding must have 4 entries) and return the
        output tensor handle."""
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens
3031
Kevin Cheng1533b852021-09-01 12:51:58 -07003032 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
3033 assert len(padding) == 6
3034 result_tens = OutputShaper.conv3dOp(
3035 self.ser, ifm, filter, strides, padding, dilations
3036 )
3037
3038 attr = ts.TosaSerializerAttribute()
3039 attr.ConvAttribute(padding, strides, dilations)
3040
3041 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003042 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07003043 )
3044 return result_tens
3045
Kevin Cheng550ccc52021-03-03 11:21:43 -08003046 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07003047 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003048 ):
3049 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07003050 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
3051
3052 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003053 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003054
Kevin Cheng550ccc52021-03-03 11:21:43 -08003055 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003056 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003057 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003058 return result_tens
3059
Kevin Cheng550ccc52021-03-03 11:21:43 -08003060 def build_depthwise_conv2d(
3061 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
3062 ):
3063 result_tens = OutputShaper.depthwiseConv2dOp(
3064 self.ser, ifm, filter, strides, padding, dilations
3065 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003066
3067 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003068 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003069
Kevin Cheng550ccc52021-03-03 11:21:43 -08003070 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003071 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003072 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003073 return result_tens
3074
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003075 def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
3076 result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)
3077
3078 # Invalidate Input/Output list for error if checks.
3079 input_list = [ifm.name, filter.name, bias.name]
3080 output_list = [result_tens.name]
3081 pCount, cCount = op["operands"]
3082 num_operands = pCount + cCount
3083 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3084
3085 TosaErrorValidator.evValidateErrorIfs(
3086 self.ser,
3087 validator_fcns,
3088 error_name,
3089 op=op,
3090 input_shape=ifm.shape,
3091 input_dtype=ifm.dtype,
3092 weight_dtype=filter.dtype,
3093 output_shape=result_tens.shape,
3094 output_dtype=result_tens.dtype,
3095 qinfo = qinfo,
3096 result_tensor = result_tens,
3097 input_list=input_list,
3098 output_list=output_list,
3099 num_operands=num_operands,
3100 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003101
Kevin Cheng550ccc52021-03-03 11:21:43 -08003102 self.ser.addOperator(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003103 op['op'], input_list, output_list, None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003104 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003105 return result_tens
3106
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003107 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
3108 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
3109
3110 # Invalidate Input/Output list for error if checks.
3111 input_list = [a.name, b.name]
3112 output_list = [result_tens.name]
3113 pCount, cCount = op["operands"]
3114 num_operands = pCount + cCount
3115 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3116
3117 TosaErrorValidator.evValidateErrorIfs(
3118 self.ser,
3119 validator_fcns,
3120 error_name,
3121 op=op,
3122 input_shape=a.shape,
3123 input_dtype=a.dtype,
3124 input2_shape=b.shape,
3125 input2_dtype=b.dtype,
3126 output_shape=result_tens.shape,
3127 output_dtype=result_tens.dtype,
3128 qinfo = qinfo,
3129 result_tensor = result_tens,
3130 input_list=input_list,
3131 output_list=output_list,
3132 num_operands=num_operands,
3133 )
3134
3135 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003136 return result_tens
3137
Matthew Haddond6ce7252021-09-29 15:35:44 +01003138 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
3139 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
3140
3141 # Invalidate Input/Output list for error if checks.
3142 input_list = [a.name]
3143 output_list = [result_tens.name]
3144 pCount, cCount = op["operands"]
3145 num_operands = pCount + cCount
3146 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3147
3148 TosaErrorValidator.evValidateErrorIfs(
3149 self.ser,
3150 validator_fcns,
3151 error_name,
3152 op=op,
3153 axis = axis,
3154 input_shape = a.shape,
3155 output_shape = result_tens.shape,
3156 input_dtype = a.dtype,
3157 output_dtype = result_tens.dtype,
3158 result_tensor = result_tens,
3159 input_list=input_list,
3160 output_list=output_list,
3161 num_operands=num_operands,
3162 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003163
3164 attr = ts.TosaSerializerAttribute()
3165 attr.AxisAttribute(axis)
3166
Matthew Haddond6ce7252021-09-29 15:35:44 +01003167 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003168 return result_tens
3169
3170 def build_clamp(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003171 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003172
3173 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01003174 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07003175
3176 if a.dtype == DType.FLOAT:
3177 attr.ClampAttribute(0, 0, min(v), max(v))
3178 else:
3179 attr.ClampAttribute(min(v), max(v), 0, 0)
3180
Matthew Haddon848efb42021-09-09 12:30:53 +01003181 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003182 return result_tens
3183
3184 def build_leaky_relu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003185 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003186 attr = ts.TosaSerializerAttribute()
3187
3188 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
3189
Matthew Haddon848efb42021-09-09 12:30:53 +01003190 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003191 return result_tens
3192
3193 # Needs an additional type/input
3194 def build_prelu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003195 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003196
Matthew Haddon848efb42021-09-09 12:30:53 +01003197 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003198 return result_tens
3199
Eric Kunzee5e26762020-10-13 16:11:07 -07003200 def build_sigmoid(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003201 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01003202 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003203 return result_tens
3204
3205 def build_tanh(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003206 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01003207 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003208 return result_tens
3209
Matthew Haddon818ab902021-07-27 09:12:49 +01003210 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07003211 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01003212
3213 # To store variable length list of input tensors we need to store axis along with it
3214 axis = a[-1]
3215 a = a[:-1]
3216
3217 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07003218
3219 attr = ts.TosaSerializerAttribute()
3220 attr.AxisAttribute(axis)
3221
Matthew Haddon818ab902021-07-27 09:12:49 +01003222 input_tensor_names = []
3223 for tensor in a:
3224 input_tensor_names.append(tensor.name)
3225
Matthew Haddon848efb42021-09-09 12:30:53 +01003226 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
3227 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003228
Matthew Haddone807aae2021-10-11 18:12:58 +01003229 def build_pad(self, op, a, padding, validator_fcns=None, error_name=None, qinfo=None):
3230 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003231
3232 # Need to turn the padding array into a TOSA tensor here.
3233 # This is one of the few tensor operands that does not get
3234 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08003235 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07003236
Matthew Haddone807aae2021-10-11 18:12:58 +01003237 # Invalidate Input/Output list for error if checks.
3238 input_list = [a.name, padding_tens.name]
3239 output_list = [result_tens.name]
3240 pCount, cCount = op["operands"]
3241 num_operands = pCount + cCount
3242 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3243
3244 TosaErrorValidator.evValidateErrorIfs(
3245 self.ser,
3246 validator_fcns,
3247 error_name,
3248 op=op,
3249 input_shape = a.shape,
3250 output_shape = result_tens.shape,
3251 input_dtype = a.dtype,
3252 output_dtype = result_tens.dtype,
3253 pad=padding,
3254 qinfo=qinfo,
3255 result_tensor = result_tens,
3256 input_list=input_list,
3257 output_list=output_list,
3258 num_operands=num_operands,
3259 )
3260
Kevin Cheng550ccc52021-03-03 11:21:43 -08003261 self.ser.addOperator(
Matthew Haddone807aae2021-10-11 18:12:58 +01003262 op['op'], input_list, output_list, None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003263 )
Matthew Haddone86fd342021-09-07 16:12:21 +01003264 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003265
Matthew Haddone807aae2021-10-11 18:12:58 +01003266 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
3267 result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)
3268
3269 # Invalidate Input/Output list for error if checks.
3270 input_list = [a.name]
3271 output_list = [result_tens.name]
3272 pCount, cCount = op["operands"]
3273 num_operands = pCount + cCount
3274 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3275
3276 TosaErrorValidator.evValidateErrorIfs(
3277 self.ser,
3278 validator_fcns,
3279 error_name,
3280 op=op,
3281 input_shape = a.shape,
3282 output_shape = result_tens.shape,
3283 input_dtype = a.dtype,
3284 output_dtype = result_tens.dtype,
3285 result_tensor = result_tens,
3286 input_list=input_list,
3287 output_list=output_list,
3288 num_operands=num_operands,
3289 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003290
3291 attr = ts.TosaSerializerAttribute()
3292 attr.ReshapeAttribute(newShape)
3293
Matthew Haddone807aae2021-10-11 18:12:58 +01003294 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003295 return result_tens
3296
3297 def build_reverse(self, op, a, axis):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003298 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003299
3300 attr = ts.TosaSerializerAttribute()
3301 attr.AxisAttribute(axis)
3302
Matthew Haddon848efb42021-09-09 12:30:53 +01003303 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003304 return result_tens
3305
Matthew Haddone807aae2021-10-11 18:12:58 +01003306 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
3307 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003308
Kevin Cheng550ccc52021-03-03 11:21:43 -08003309 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07003310
Matthew Haddone807aae2021-10-11 18:12:58 +01003311 # Invalidate Input/Output list for error if checks.
3312 input_list = [a.name, perms_tens.name]
3313 output_list = [result_tens.name]
3314 pCount, cCount = op["operands"]
3315 num_operands = pCount + cCount
3316 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3317
3318 TosaErrorValidator.evValidateErrorIfs(
3319 self.ser,
3320 validator_fcns,
3321 error_name,
3322 op=op,
3323 input_shape = a.shape,
3324 output_shape = result_tens.shape,
3325 perms=perms,
3326 input_dtype = a.dtype,
3327 output_dtype = result_tens.dtype,
3328 result_tensor = result_tens,
3329 input_list=input_list,
3330 output_list=output_list,
3331 num_operands=num_operands,
3332 )
3333
3334
3335 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07003336 return result_tens
3337
Matthew Haddone807aae2021-10-11 18:12:58 +01003338 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
3339 result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)
3340
3341 # Invalidate Input/Output list for error if checks.
3342 input_list = [a.name]
3343 output_list = [result_tens.name]
3344 pCount, cCount = op["operands"]
3345 num_operands = pCount + cCount
3346 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3347
3348 TosaErrorValidator.evValidateErrorIfs(
3349 self.ser,
3350 validator_fcns,
3351 error_name,
3352 op=op,
3353 input_shape = a.shape,
3354 output_shape = result_tens.shape,
3355 input_dtype = a.dtype,
3356 output_dtype = result_tens.dtype,
3357 start=start,
3358 size=size,
3359 result_tensor = result_tens,
3360 input_list=input_list,
3361 output_list=output_list,
3362 num_operands=num_operands,
3363 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003364
3365 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01003366 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07003367
Matthew Haddone807aae2021-10-11 18:12:58 +01003368 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003369 return result_tens
3370
3371 def build_tile(self, op, a, multiples):
3372 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
3373
3374 attr = ts.TosaSerializerAttribute()
3375 attr.TileAttribute(multiples)
3376
Matthew Haddon848efb42021-09-09 12:30:53 +01003377 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003378 return result_tens
3379
Kevin Cheng77d0f762020-11-24 10:26:32 -08003380 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07003381
3382 # Create a new indicies tensor
3383 # here with data that doesn't exceed the dimensions of the values tensor
3384
Kevin Cheng550ccc52021-03-03 11:21:43 -08003385 K = values.shape[1] # K
3386 W = self.randInt(
3387 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
3388 ) # W
3389 indicies_arr = np.int32(
3390 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
3391 ) # (N, W)
3392 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003393
Kevin Cheng77d0f762020-11-24 10:26:32 -08003394 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07003395
Matthew Haddon848efb42021-09-09 12:30:53 +01003396 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003397
3398 return result_tens
3399
Kevin Cheng77d0f762020-11-24 10:26:32 -08003400 def build_scatter(self, op, values_in, input):
3401
3402 # Create a new indicies tensor
3403 # here with data that doesn't exceed the dimensions of the values_in tensor
3404
Kevin Cheng550ccc52021-03-03 11:21:43 -08003405 K = values_in.shape[1] # K
3406 W = input.shape[1] # W
3407 indicies_arr = np.int32(
3408 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
3409 ) # (N, W)
3410 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003411
3412 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
3413
Kevin Cheng550ccc52021-03-03 11:21:43 -08003414 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003415 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003416 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08003417
3418 return result_tens
3419
Matthew Haddon848efb42021-09-09 12:30:53 +01003420
Kevin Cheng550ccc52021-03-03 11:21:43 -08003421 def build_resize(
3422 self,
3423 op,
3424 input,
3425 mode,
3426 stride,
3427 offset,
3428 shift,
3429 stride_fp,
3430 offset_fp,
3431 output_dims,
3432 input_dtype,
3433 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003434 validator_fcns,
3435 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003436 ):
3437 result_tens = OutputShaper.resizeOp(
3438 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003439 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003440 input,
3441 mode,
3442 stride,
3443 offset,
3444 shift,
3445 stride_fp,
3446 offset_fp,
3447 output_dims,
3448 input_dtype,
3449 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003450 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08003451 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003452
Matthew Haddon848efb42021-09-09 12:30:53 +01003453 # Invalidate Input/Output list for error if checks.
3454 input_list = [input.name]
3455 output_list = [result_tens.name]
3456 pCount, cCount = op["operands"]
3457 num_operands = pCount + cCount
3458 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01003459
Matthew Haddon848efb42021-09-09 12:30:53 +01003460 TosaErrorValidator.evValidateErrorIfs(
3461 self.ser,
3462 validator_fcns,
3463 error_name,
3464 op=op,
3465 mode=mode,
3466 shift=shift,
3467 input_dtype=input_dtype,
3468 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003469 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01003470 output_shape=output_dims,
3471 offset=offset,
3472 offset_fp=offset_fp,
3473 stride=stride,
3474 stride_fp=stride_fp,
3475 input_list=input_list,
3476 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003477 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01003478 num_operands=num_operands,
3479 )
Matthew Haddone86fd342021-09-07 16:12:21 +01003480
Eric Kunzee5e26762020-10-13 16:11:07 -07003481 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08003482
Kevin Cheng550ccc52021-03-03 11:21:43 -08003483 attr.ResizeAttribute(
3484 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
3485 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003486
Matthew Haddon848efb42021-09-09 12:30:53 +01003487 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003488 return result_tens
3489
3490 def build_identityn(self, op, val, val2):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003491 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
3492 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003493 self.ser.addOperator(
3494 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
3495 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003496 return result_tens
3497
Kevin Cheng17e92022021-10-01 14:33:33 -07003498 def build_const(self, op, val):
3499 self.ser.addOutputTensor(val)
3500 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07003501
3502 # Type Conversion
3503 def build_cast(self, op, val, out_dtype):
3504 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01003505 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003506 return result_tens
3507
Matthew Haddonc2025212021-10-08 21:21:05 +01003508 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
Eric Kunzee5e26762020-10-13 16:11:07 -07003509 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
3510
3511 if per_channel:
3512 nc = val.shape[-1]
3513 else:
3514 nc = 1
3515
3516 in_type_width = self.typeWidth(val.dtype)
3517 out_type_width = self.typeWidth(out_dtype)
3518
Kevin Cheng3a478572021-01-22 17:21:02 -08003519 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003520 input_zp = self.randInt(-128, 128)
3521 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07003522 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003523 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07003524 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01003525 elif error_name == ErrorIf.InputZeroPointNotZero:
3526 input_zp = self.randInt(-128, 128)
3527 if input_zp == 0:
3528 input_zp = input_zp + self.rng.integers(1, 10)
3529 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003530 else:
3531 input_zp = 0
3532
Kevin Cheng3a478572021-01-22 17:21:02 -08003533 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003534 output_zp = self.randInt(-128, 128)
3535 out_type_width = out_type_width + 1
3536 elif out_dtype == DType.UINT8:
3537 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07003538 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01003539 elif error_name == ErrorIf.OutputZeroPointNotZero:
3540 output_zp = self.randInt(-128, 128)
3541 if output_zp == 0:
3542 output_zp = output_zp + self.rng.integers(1, 10)
3543 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003544 else:
3545 output_zp = 0
3546
3547 # Calculate scale based on:
3548 # scale = a *(2^output_width)/(2^input_width))
3549
3550 a = np.float32(self.rng.random(size=[nc]))
3551 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
3552
3553 if scale32:
3554 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01003555 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07003556 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
3557 else:
3558 # Cap the scaling at 2^15 - 1 for scale16
3559 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3560
Kevin Cheng550ccc52021-03-03 11:21:43 -08003561 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003562
3563 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3564 shift_arr = np.int32(np.zeros(shape=[nc]))
3565
3566 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003567 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
3568 scale_arr[i], scale32
3569 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003570
Kevin Cheng550ccc52021-03-03 11:21:43 -08003571 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07003572
Matthew Haddonc2025212021-10-08 21:21:05 +01003573 # Invalidate Input/Output list for error if checks.
3574 input_list = [val.name]
3575 output_list = [result_tens.name]
3576 pCount, cCount = op["operands"]
3577 num_operands = pCount + cCount
3578 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3579
3580 qinfo = (input_zp, output_zp)
3581 TosaErrorValidator.evValidateErrorIfs(
3582 self.ser,
3583 validator_fcns,
3584 error_name,
3585 op=op,
3586 input_dtype=val.dtype,
3587 output_dtype=out_dtype,
3588 input_shape=val.shape,
3589 qinfo=qinfo,
3590 scale32 = scale32,
3591 double_round = double_round,
3592 input_list=input_list,
3593 output_list=output_list,
3594 result_tensor=result_tens,
3595 num_operands=num_operands,
3596 )
3597
Eric Kunzee5e26762020-10-13 16:11:07 -07003598 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003599 attr.RescaleAttribute(
3600 input_zp,
3601 output_zp,
3602 multiplier_arr,
3603 shift_arr,
3604 scale32,
3605 double_round,
3606 per_channel,
3607 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003608
Matthew Haddonc2025212021-10-08 21:21:05 +01003609 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003610 return result_tens
3611
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Build a COND_IF operator whose then/else blocks contain only const nodes.

        Args:
            op: entry from the operator table (dict with 'op', ...).
            then_tens, else_tens: tensors used only for their shape.
            cond: Python bool baked into a const condition tensor.

        Returns:
            The serializer handle of the output tensor.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
3647
    def build_cond_if_binary(self, op, a, b, cond):
        """Build a COND_IF operator with a binary op inside each of then/else.

        Args:
            op: entry from the operator table (dict with 'op', ...).
            a, b: input tensor handles fed to both blocks.
            cond: Python bool baked into a const condition tensor.

        Returns:
            The serializer handle of the output tensor.
        """
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # Choose block bodies supported by the dtype: add/sub for wide types,
        # shifts for the narrow integer types.
        if a.dtype in (DType.FLOAT, DType.INT32):
            then_op, else_op = Op.ADD, Op.SUB
        elif a.dtype in (DType.INT8, DType.INT16):
            then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
        else:
            assert False, f"No tests for DType: {a.dtype}"

        # NOTE: the loop variable deliberately shadows the `op` parameter,
        # which is not used again after this point.
        for block, op in ((then_block, then_op), (else_block, else_op)):
            self.ser.startBasicBlock(block)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
            self.ser.addOperator(op, [a.name, b.name], [tens.name])

        return result_tens
3683
    def build_while_loop(self, op, a, iter_val):
        """Build a WHILE_LOOP operator that accumulates `a` for iter_val iterations.

        The loop carries (iter, a, acc); the COND block tests iter > 0 and the
        BODY block computes acc += a and iter -= 1.

        Args:
            op: entry from the operator table (dict with 'op', ...).
            a: tensor added into the accumulator each iteration.
            iter_val: number of iterations, baked into a placeholder tensor.

        Returns:
            The serializer handle of the accumulator output tensor.
        """
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # COND block (input: iter, output: cond_tens )
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
3737
Matthew Haddon1c00b712021-10-01 15:51:03 +01003738 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
3739 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
3740 default_test_rank_range = range(1, 5)
3741 if not shapeFilter:
3742 shapeFilter = [None]
3743
3744 # Calculate the filters based on what is requested and what the operator allows
3745 rmin, rmax = op["rank"]
3746 if rankFilter is not None:
3747 cleanRankFilter = []
3748 # Ensure rankFilter values are allowed by operator
3749 for rank in rankFilter:
3750 if rank >= rmin and rank <= rmax:
3751 cleanRankFilter.append(rank)
3752 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01003753 # Ensure default behaviour is bounded by default range or by operator,
3754 # whichever is the smaller range of ranks.
3755 opRankRange = range(rmin, rmax + 1)
3756 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01003757 else:
3758 cleanRankFilter = range(rmin, rmax + 1)
3759
3760 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003761
Matthew Haddon1c00b712021-10-01 15:51:03 +01003762 if dtypeFilter is not None:
3763 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01003764 # Create list of operator dtypes filtered by requested dtypes
3765 for dtype in dtypes:
3766 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01003767 cleanDtypeFilter.append(dtype)
3768 else:
3769 cleanDtypeFilter = dtypes
3770
3771 if testType == 'positive':
3772 filterDict = {
3773 'shapeFilter': shapeFilter,
3774 'rankFilter': cleanRankFilter,
3775 'dtypeFilter': cleanDtypeFilter
3776 }
3777 return filterDict
3778 elif testType == 'negative':
Matthew Haddone807aae2021-10-11 18:12:58 +01003779 if validator is not None:
3780 validator_info = validator(check=False, op=op)
3781 else:
3782 return None
3783
Matthew Haddon1c00b712021-10-01 15:51:03 +01003784 error_arguments = validator_info['param_reqs']
3785
3786 #Set parameters as required
3787 if error_arguments['rank'] != None:
3788 rankFilter = error_arguments['rank']
3789 else:
3790 rankFilter = cleanRankFilter
3791
3792 if error_arguments['dtype'] != None:
3793 dtypeFilter = error_arguments['dtype']
3794 else:
3795 dtypeFilter = cleanDtypeFilter
3796
3797 if error_arguments['shape'] != None:
3798 shapeFilter = error_arguments['shape']
3799 else:
3800 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
3801
3802 filterDict = {
3803 'shapeFilter': shapeFilter,
3804 'rankFilter': rankFilter,
3805 'dtypeFilter': dtypeFilter
3806 }
3807 return filterDict
3808
3809
Kevin Cheng550ccc52021-03-03 11:21:43 -08003810 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01003811 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08003812 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07003813
3814 try:
3815 op = self.TOSA_OP_LIST[opName]
3816 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003817 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003818
3819 # Initialize a new random number generator
3820 self.rng = np.random.default_rng(self.random_seed)
3821
Kevin Cheng550ccc52021-03-03 11:21:43 -08003822 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003823
Eric Kunzee5e26762020-10-13 16:11:07 -07003824 # Test list consists of a tuple of:
3825 # (opName, testNameStr, dtype, shapeList, argumentsList)
3826 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003827 if testType == 'negative' and "error_if_validators" in op:
3828 error_if_validators = op["error_if_validators"]
3829 else:
3830 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07003831
Matthew Haddon1c00b712021-10-01 15:51:03 +01003832 for validator in error_if_validators:
3833 if validator is not None:
3834 error_name = validator(check=False, op=op)['error_name']
Matthew Haddon1c00b712021-10-01 15:51:03 +01003835 else:
3836 error_name = None
3837
3838 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
Matthew Haddone807aae2021-10-11 18:12:58 +01003839 if filterDict == None:
3840 return []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003841 cleanRankFilter = filterDict['rankFilter']
3842 cleanDtypeFilter = filterDict['dtypeFilter']
3843 cleanShapeFilter = filterDict['shapeFilter']
3844 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
3845
3846 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07003847 if opName.startswith("conv3d"):
3848 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01003849 for t in cleanDtypeFilter:
3850 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01003851 # Filter out by rank
3852 if shape is not None and len(shape) != r:
3853 continue
Matthew Haddon74567092021-07-16 15:38:20 +01003854 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003855 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003856
Matthew Haddon74567092021-07-16 15:38:20 +01003857 shapeStr = self.shapeStr(shapeList[0])
3858 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07003859
Matthew Haddon74567092021-07-16 15:38:20 +01003860 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
3861 argList = []
3862 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003863 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003864 else:
Matthew Haddon74567092021-07-16 15:38:20 +01003865 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07003866
Matthew Haddon74567092021-07-16 15:38:20 +01003867 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003868 if testType == 'positive':
3869 if argStr:
3870 testStr = "{}_{}_{}_{}".format(
3871 opName, shapeStr, typeStr, argStr
3872 )
3873 else:
3874 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
3875 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01003876 if argStr:
3877 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
3878 opName, error_name, shapeStr, typeStr, argStr
3879 )
3880 else:
3881 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003882
3883 testList.append((opName, testStr, t, error_name, shapeList, args))
3884
3885 if testType == 'positive':
3886 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
3887 if "invalid_test_validators" in op:
3888 invalid_test_validators = op["invalid_test_validators"]
3889 clean_testList = []
3890 for test in testList:
3891 for validator_fcn in invalid_test_validators:
3892 remove_test = False
3893 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
3894 remove_test = True
3895 if not remove_test:
3896 clean_testList.append(test)
3897 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07003898
3899 return testList
3900
Matthew Haddone86fd342021-09-07 16:12:21 +01003901
3902 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07003903 try:
3904 op = self.TOSA_OP_LIST[opName]
3905 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003906 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003907
3908 # Create a serializer
3909 self.createSerializer(opName, testStr)
3910
Kevin Cheng550ccc52021-03-03 11:21:43 -08003911 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003912 if "error_if_validators" in op:
3913 error_if_validators = op["error_if_validators"]
3914 else:
3915 error_if_validators = None
3916
Kevin Cheng550ccc52021-03-03 11:21:43 -08003917 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003918 num_operands = pCount + cCount
3919
3920 if isinstance(dtype_or_dtypeList, list):
3921 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07003922 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003923 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003924 else:
3925 dtypeList = [dtype_or_dtypeList] * (num_operands)
3926
Kevin Cheng93a16282021-08-31 16:14:03 -07003927 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003928 assert (
3929 len(shapeList) == num_operands
3930 ), "shapeList length {} must match number of operands {}".format(
3931 len(shapeList), num_operands
3932 )
3933 assert (
3934 len(dtypeList) == num_operands
3935 ), "dtypeList length {} must match number of operands {}".format(
3936 len(dtypeList), num_operands
3937 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003938
3939 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003940 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003941 except KeyError:
3942 qgen = None
3943
3944 # Build the random tensor operands and the test
3945 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08003946
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003947 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003948
3949 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003950 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003951 else:
3952 qinfo = None
3953
3954 try:
3955 if error_if_validators is None:
3956 if qinfo is not None:
3957 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
3958 else:
3959 resultName = build_fcn(self, op, *tens, *testArgs)
3960 else:
3961 if qinfo is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003962 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003963 else:
3964 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
3965 except TypeError as e:
3966 print(
3967 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
3968 build_fcn, tens, testArgs
3969 )
3970 )
3971 raise e
3972
3973 if resultName is None:
3974 print("Invalid ERROR_IF tests created")
3975
3976 # Save the serialized test
3977 self.serialize("test")
3978
3979
    def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create the operand tensors for a test.

        Most ops get plain random placeholder/const tensors, but several ops
        need specially constrained data so that positive tests cannot trip
        runtime errors or saturate:

        * ADD/SUB (INT32): clip operand b so the result fits in int32
        * ARITHMETIC_RIGHT_SHIFT: operand[1] limited to [0, num_bits)
        * SELECT: condition tensor forced to BOOL
        * INTDIV: regenerate until no divide-by-zero / INT_MIN // -1 cases
        * MUL (integer): halve operands until the product fits in int32
        * CONCAT: split inputs between placeholders and consts per self.args

        Constraints are skipped for negative (error_name set) tests where the
        invalid data is the point of the test.

        Returns:
            List of serializer tensor handles, placeholders first then consts.
        """
        pCount, cCount = op["operands"]

        tens = []
        if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = (op["op"] == Op.ADD)
            a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                # Compute in int64 so the unclipped result is exact
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31)-1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                # (b may be broadcast, so collapse any size-1 axes back down)
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            # Force value of operand[1] to be within [0, num_bits]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Shift-amount operand: bound by the input's bit width
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            # NOTE: mutates the caller's dtypeList in place
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.INTDIV and error_name == None:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Rejection-sample until neither case is present.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop regenerates a_arr/b_arr on every
                # iteration so only the final pass matters - looks redundant;
                # confirm intent before simplifying.
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    # Evaluate the scaled product in int64 to detect overflow
                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2 ** 31)).all() and (
                        result_arr <= ((2 ** 31) - 1)
                    ).all():
                        break

                    # Halve both operands until the product fits in int32
                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            # Split inputs into placeholders and consts per the command-line
            # num_const_inputs_concat setting (always at least 1 placeholder)
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            # Default path: random placeholders then consts, per the op's
            # (placeholder, const) operand counts
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004180
4181 def createDynamicOpLists(self):
4182
4183 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07004184 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004185
Kevin Cheng1533b852021-09-01 12:51:58 -07004186 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004187 testName = "conv2d_{}x{}".format(k[0], k[1])
4188 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
4189 self.TOSA_OP_LIST[testName]["filter"] = k
4190 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004191
Kevin Cheng550ccc52021-03-03 11:21:43 -08004192 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
4193 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
4194 "depthwise_conv2d_TEMPLATE"
4195 ].copy()
4196 self.TOSA_OP_LIST[testName]["filter"] = k
4197 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004198
Kevin Cheng550ccc52021-03-03 11:21:43 -08004199 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
4200 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
4201 "transpose_conv2d_TEMPLATE"
4202 ].copy()
4203 self.TOSA_OP_LIST[testName]["filter"] = k
4204 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004205
Kevin Cheng1533b852021-09-01 12:51:58 -07004206 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
4207 for k in KERNELS_3D:
4208 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
4209 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
4210 self.TOSA_OP_LIST[testName]["filter"] = k
4211 self.TOSA_OP_LIST[testName]["template"] = False
4212
Eric Kunzee5e26762020-10-13 16:11:07 -07004213 # Delete any templates after having created any dynamic ops
4214 # This is a two-pass operation because it's bad practice to delete
4215 # keys from dictionaries while iterating
4216 keyList = []
4217 for k in self.TOSA_OP_LIST:
4218 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004219 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07004220 keyList.append(k)
4221 continue
4222 except KeyError:
4223 pass
4224
4225 for k in keyList:
4226 del self.TOSA_OP_LIST[k]
4227
4228 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004229 """Fill in default fields for ops if they aren't already specified.
4230 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07004231 for op in self.TOSA_OP_LIST:
4232
4233 # Required fields
4234 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004235 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004236 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004237 raise Exception(
4238 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
4239 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004240
4241 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004242 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004243 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004244 raise Exception(
4245 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
4246 op
4247 )
4248 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004249
4250 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004252 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004253 raise Exception(
4254 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
4255 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004256
4257 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004258 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004259 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 raise Exception(
4261 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
4262 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004263
4264 # Put in default rank range, if missing
4265 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004266 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004267 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004268 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07004269
4270 # Tensor operator list
4271 # 'op': op name
4272 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08004273 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
4274 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07004275 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
4276 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08004277 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07004278
Kevin Cheng550ccc52021-03-03 11:21:43 -08004279 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
4280 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07004281
Kevin Cheng550ccc52021-03-03 11:21:43 -08004282 TYPE_BOOL = [DType.BOOL]
4283 TYPE_FI32 = [DType.FLOAT, DType.INT32]
4284 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
4285 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07004286
Kevin Cheng550ccc52021-03-03 11:21:43 -08004287 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07004288
Kevin Cheng1533b852021-09-01 12:51:58 -07004289 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07004290 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07004291 [DType.INT8, DType.INT8, DType.INT32],
4292 [DType.INT16, DType.INT8, DType.INT48],
4293 DType.FLOAT,
4294 ]
4295
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01004296 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07004297
4298 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08004299 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004300 "argmax": {
4301 "op": Op.ARGMAX,
4302 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004303 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004304 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4305 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004306 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
4307 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4308 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004309 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004310 "avg_pool2d": {
4311 "op": Op.AVG_POOL2D,
4312 "operands": (1, 0),
4313 "rank": (4, 4),
4314 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4315 "qgen": TosaQuantGen.qgUnary,
4316 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004317 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4318 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4319 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4320 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4321 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004323 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004324 "conv2d_TEMPLATE": {
4325 "op": Op.CONV2D,
4326 "operands": (1, 2),
4327 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01004328 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004329 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004330 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004331 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004332 "template": True,
4333 },
Kevin Cheng1533b852021-09-01 12:51:58 -07004334 # Templated operator. Filled in by createDynamicOpLists
4335 "conv3d_TEMPLATE": {
4336 "op": Op.CONV3D,
4337 "operands": (1, 2),
4338 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01004339 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07004340 "qgen": TosaQuantGen.qgConv,
4341 "types": TYPE_CONV,
4342 "template": True,
4343 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004344 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004345 "depthwise_conv2d_TEMPLATE": {
4346 "op": Op.DEPTHWISE_CONV2D,
4347 "operands": (1, 2),
4348 "filter": [1, 1],
4349 "rank": (4, 4),
4350 "build_fcn": (
4351 build_depthwise_conv2d,
4352 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01004353 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004354 ),
4355 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004356 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004357 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004358 "template": True,
4359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004360 "fully_connected": {
4361 "op": Op.FULLY_CONNECTED,
4362 "operands": (1, 2),
4363 "rank": (2, 2),
4364 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
4365 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004366 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004367 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
4368 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004370 "matmul": {
4371 "op": Op.MATMUL,
4372 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07004373 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08004374 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
4375 "qgen": TosaQuantGen.qgMatmul,
4376 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004377 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4378 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004379 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004380 "max_pool2d": {
4381 "op": Op.MAX_POOL2D,
4382 "operands": (1, 0),
4383 "rank": (4, 4),
4384 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4385 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004386 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4387 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4388 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4389 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004390 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004391 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004392 "transpose_conv2d_TEMPLATE": {
4393 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07004394 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004395 "rank": (4, 4),
4396 "build_fcn": (
4397 build_transpose_conv2d,
4398 TosaTensorGen.tgTransposeConv2D,
4399 TosaArgGen.agTransposeConv2D,
4400 ),
4401 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004402 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004403 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004404 "template": True,
4405 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004406 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08004407 "clamp": {
4408 "op": Op.CLAMP,
4409 "operands": (1, 0),
4410 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
4411 "types": TYPE_NARROW_INT_FP,
4412 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004413 "sigmoid": {
4414 "op": Op.SIGMOID,
4415 "operands": (1, 0),
4416 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
4417 "types": TYPE_FP,
4418 },
4419 "tanh": {
4420 "op": Op.TANH,
4421 "operands": (1, 0),
4422 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
4423 "types": TYPE_FP,
4424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004425 # Elementwise Binary Operators
4426 "add": {
4427 "op": Op.ADD,
4428 "operands": (2, 0),
4429 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4430 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004431 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4432 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004434 "arithmetic_right_shift": {
4435 "op": Op.ARITHMETIC_RIGHT_SHIFT,
4436 "operands": (2, 0),
4437 "build_fcn": (
4438 build_arithmetic_right_shift,
4439 TosaTensorGen.tgBroadcastFuzz,
4440 TosaArgGen.agArithmeticRightShift,
4441 ),
4442 "types": TYPE_INT,
4443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004444 "bitwise_and": {
4445 "op": Op.BITWISE_AND,
4446 "operands": (2, 0),
4447 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4448 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004449 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4450 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004451 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004452 "bitwise_or": {
4453 "op": Op.BITWISE_OR,
4454 "operands": (2, 0),
4455 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4456 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004457 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4458 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004459 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004460 "bitwise_xor": {
4461 "op": Op.BITWISE_XOR,
4462 "operands": (2, 0),
4463 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4464 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004465 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4466 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004467 },
Matthew Haddon459443c2021-08-23 16:43:13 +01004468 "intdiv": {
4469 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004470 "operands": (2, 0),
4471 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4472 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004473 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4474 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004475 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004476 "logical_and": {
4477 "op": Op.LOGICAL_AND,
4478 "operands": (2, 0),
4479 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4480 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004481 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4482 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004483 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004484 "logical_left_shift": {
4485 "op": Op.LOGICAL_LEFT_SHIFT,
4486 "operands": (2, 0),
4487 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4488 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004489 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4490 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004491 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004492 "logical_right_shift": {
4493 "op": Op.LOGICAL_RIGHT_SHIFT,
4494 "operands": (2, 0),
4495 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4496 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004497 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4498 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004499 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004500 "logical_or": {
4501 "op": Op.LOGICAL_OR,
4502 "operands": (2, 0),
4503 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4504 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004505 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4506 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004507 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004508 "logical_xor": {
4509 "op": Op.LOGICAL_XOR,
4510 "operands": (2, 0),
4511 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4512 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004513 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4514 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004515 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004516 "maximum": {
4517 "op": Op.MAXIMUM,
4518 "operands": (2, 0),
4519 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4520 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004521 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4522 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004524 "minimum": {
4525 "op": Op.MINIMUM,
4526 "operands": (2, 0),
4527 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4528 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004529 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4530 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004531 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004532 "mul": {
4533 "op": Op.MUL,
4534 "operands": (2, 0),
4535 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
4536 "types": TYPE_INT_FP,
4537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004538 "pow": {
4539 "op": Op.POW,
4540 "operands": (2, 0),
4541 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
4542 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004543 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4544 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004545 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004546 "sub": {
4547 "op": Op.SUB,
4548 "operands": (2, 0),
4549 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4550 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004551 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4552 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004553 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004554 "table": {
4555 "op": Op.TABLE,
4556 # Use the automatic generation functions to create the input array
4557 # but create the table tensor in the build function, as it may be
4558 # a different type from the input
4559 "operands": (1, 0),
4560 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004561 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08004562 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004563 # Elementwise Unary operators
4564 "abs": {
4565 "op": Op.ABS,
4566 "operands": (1, 0),
4567 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4568 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004569 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4570 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004571 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004572 "bitwise_not": {
4573 "op": Op.BITWISE_NOT,
4574 "operands": (1, 0),
4575 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4576 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004577 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4578 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004579 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004580 "ceil": {
4581 "op": Op.CEIL,
4582 "operands": (1, 0),
4583 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4584 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004585 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4586 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004587 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004588 "clz": {
4589 "op": Op.CLZ,
4590 "operands": (1, 0),
4591 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4592 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004593 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4594 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004595 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004596 "exp": {
4597 "op": Op.EXP,
4598 "operands": (1, 0),
4599 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4600 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004601 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4602 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004604 "floor": {
4605 "op": Op.FLOOR,
4606 "operands": (1, 0),
4607 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4608 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004609 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4610 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004611 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004612 "log": {
4613 "op": Op.LOG,
4614 "operands": (1, 0),
4615 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4616 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004617 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4618 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004619 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004620 "logical_not": {
4621 "op": Op.LOGICAL_NOT,
4622 "operands": (1, 0),
4623 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4624 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004625 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4626 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004627 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004628 "negate": {
4629 "op": Op.NEGATE,
4630 "operands": (1, 0),
4631 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4632 "qgen": TosaQuantGen.qgUnary,
4633 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004634 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4635 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4636 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004637 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004638 "reciprocal": {
4639 "op": Op.RECIPROCAL,
4640 "operands": (1, 0),
4641 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4642 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004643 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4644 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004646 "rsqrt": {
4647 "op": Op.RSQRT,
4648 "operands": (1, 0),
4649 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4650 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004651 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4652 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004653 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004654 # Elementwise Ternary operators
4655 "select": {
4656 "op": Op.SELECT,
4657 "operands": (3, 0),
4658 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
4659 "types": TYPE_FIB,
4660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004661 # Comparison operators
4662 "equal": {
4663 "op": Op.EQUAL,
4664 "operands": (2, 0),
4665 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4666 "types": TYPE_FI32,
4667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004668 "greater_equal": {
4669 "op": Op.GREATER_EQUAL,
4670 "operands": (2, 0),
4671 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4672 "types": TYPE_FI32,
4673 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004674 "greater": {
4675 "op": Op.GREATER,
4676 "operands": (2, 0),
4677 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4678 "types": TYPE_FI32,
4679 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004680 # Reduction operators
4681 "reduce_all": {
4682 "op": Op.REDUCE_ALL,
4683 "operands": (1, 0),
4684 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4685 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004686 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4687 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4688 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004690 "reduce_any": {
4691 "op": Op.REDUCE_ANY,
4692 "operands": (1, 0),
4693 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4694 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004695 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4696 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4697 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004698 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004699 "reduce_max": {
4700 "op": Op.REDUCE_MAX,
4701 "operands": (1, 0),
4702 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4703 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004704 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4705 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4706 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004707 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004708 "reduce_min": {
4709 "op": Op.REDUCE_MAX,
4710 "operands": (1, 0),
4711 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4712 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004713 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4714 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4715 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004716 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004717 "reduce_product": {
4718 "op": Op.REDUCE_PRODUCT,
4719 "operands": (1, 0),
4720 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4721 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004722 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4723 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4724 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004725 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004726 "reduce_sum": {
4727 "op": Op.REDUCE_SUM,
4728 "operands": (1, 0),
4729 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4730 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004731 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4732 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4733 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004734 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004735 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004736 "concat": {
4737 "op": Op.CONCAT,
4738 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01004739 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004740 "types": TYPE_FIB,
4741 },
4742 "pad": {
4743 "op": Op.PAD,
4744 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004745 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004746 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4747 "qgen": TosaQuantGen.qgPad,
4748 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004749 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
4750 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004751 },
4752 "reshape": {
4753 "op": Op.RESHAPE,
4754 "operands": (1, 0),
4755 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4756 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004757 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4758 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004759 },
4760 "reverse": {
4761 "op": Op.REVERSE,
4762 "operands": (1, 0),
4763 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4764 "types": TYPE_FIB,
4765 },
4766 "slice": {
4767 "op": Op.SLICE,
4768 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004769 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004770 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4771 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004772 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
4773 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
4774 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004775 },
4776 "tile": {
4777 "op": Op.TILE,
4778 "operands": (1, 0),
4779 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4780 "types": TYPE_FIB,
4781 },
4782 "transpose": {
4783 "op": Op.TRANSPOSE,
4784 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004785 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004786 "build_fcn": (
4787 build_transpose,
4788 TosaTensorGen.tgBasic,
4789 TosaArgGen.agTranspose,
4790 ),
4791 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004792 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
4793 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004794 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004795 # Data nodes
4796 "const": {
4797 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004798 "operands": (0, 1),
4799 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004800 "types": TYPE_FIB,
4801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004802 "identity": {
4803 "op": Op.IDENTITY,
4804 "operands": (1, 0),
4805 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4806 "types": TYPE_FIB,
4807 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004808 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004809 "gather": {
4810 "op": Op.GATHER,
4811 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4812 "operands": (1, 0),
4813 "rank": (3, 3),
4814 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4815 "types": TYPE_INT_FP,
4816 },
4817 "scatter": {
4818 "op": Op.SCATTER,
4819 # Only specify 'values_in' tensor here.
4820 #'indices' and 'input' are generated in op building stage
4821 "operands": (2, 0),
4822 "rank": (3, 3),
4823 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4824 "types": TYPE_INT_FP,
4825 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004826 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004827 "resize": {
4828 "op": Op.RESIZE,
4829 "operands": (1, 0),
4830 "rank": (4, 4),
4831 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4832 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004833 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4834 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4835 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004836 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004837 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4838 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004839 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004840 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004841 "cast": {
4842 "op": Op.CAST,
4843 "operands": (1, 0),
4844 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4845 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4846 },
4847 "rescale": {
4848 "op": Op.RESCALE,
4849 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01004850 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004851 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004852 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01004853 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
4854 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4855 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004857 # Custom
4858 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004859 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004860 # Two varients of cond_if, one that generates one of two constant tensors (no
4861 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4862 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004863 "cond_if_const": {
4864 "op": Op.COND_IF,
4865 "operands": (0, 2),
4866 "build_fcn": (
4867 build_cond_if_const,
4868 TosaTensorGen.tgBasic,
4869 TosaArgGen.agCondIf,
4870 ),
4871 "types": [DType.BOOL],
4872 },
4873 "cond_if_binary": {
4874 "op": Op.COND_IF,
4875 "operands": (2, 0),
4876 "build_fcn": (
4877 build_cond_if_binary,
4878 TosaTensorGen.tgBasic,
4879 TosaArgGen.agCondIf,
4880 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004881 "types": TYPE_INT_FP,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004882 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004883 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004884 "while_loop": {
4885 "op": Op.WHILE_LOOP,
4886 "operands": (0, 1),
4887 "build_fcn": (
4888 build_while_loop,
4889 TosaTensorGen.tgBasic,
4890 TosaArgGen.agWhileLoop,
4891 ),
4892 "types": [DType.INT32],
4893 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004894 }
4895
Kevin Cheng550ccc52021-03-03 11:21:43 -08004896
Eric Kunzee5e26762020-10-13 16:11:07 -07004897class OutputShaper:
4898 # Methods in this class compute the expected output shape and datatype
4899 # for common classes of operations
4900 def __init__(self):
4901 pass
4902
4903 # These methods return arguments that can be used for
4904 # creating a new output tensor
4905 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004906 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4907 if error_name != ErrorIf.RankMismatch:
4908 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004909 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004910
4911 shape = []
4912 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004913 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004914 shape.append(b.shape[i])
4915 else:
4916 shape.append(a.shape[i])
4917
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004918 if error_name == ErrorIf.WrongOutputType:
4919 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4920 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4921 outputDType = rng.choice(wrong_dtypes)
4922 else:
4923 outputDType = a.dtype
4924
4925 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004926
4927 @staticmethod
4928 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004929 assert len(a.shape) == len(b.shape)
4930 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004931
4932 shape = []
4933 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004934 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004935 shape.append(a.shape[i])
4936
Kevin Cheng550ccc52021-03-03 11:21:43 -08004937 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004938
4939 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004940 def unaryOp(ser, rng, a, error_name=None):
4941 if error_name == ErrorIf.WrongOutputType:
4942 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4943 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4944 outputDType = rng.choice(wrong_dtypes)
4945 else:
4946 outputDType = a.dtype
4947
4948 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004949
4950 @staticmethod
4951 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004952 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
4953 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004954
4955 shape = []
4956 for i in range(len(a.shape)):
4957 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4958
Kevin Cheng550ccc52021-03-03 11:21:43 -08004959 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004960
4961 @staticmethod
4962 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004963 assert len(a.shape) == len(b.shape)
4964 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004965
4966 # Do broadcast
4967 shape = []
4968 for i in range(len(a.shape)):
4969 if a.shape[i] == 1:
4970 shape.append(b.shape[i])
4971 else:
4972 shape.append(a.shape[i])
4973
4974 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08004975 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004976
4977 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004978 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004979 shape = a.shape.copy()
Matthew Haddond6ce7252021-09-29 15:35:44 +01004980 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
4981 shape[axis] = 1
4982 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4983 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004984
Matthew Haddond6ce7252021-09-29 15:35:44 +01004985 if error_name == ErrorIf.WrongOutputType:
4986 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4987 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4988 outputDType = rng.choice(wrong_dtypes)
4989 else:
4990 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004991
Matthew Haddond6ce7252021-09-29 15:35:44 +01004992 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004993
4994 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004995 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004996 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004997
4998 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4999 del shape[axis]
5000
5001 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5002 remove = rng.choice([True, False])
5003 if remove and len(shape) > 1:
5004 del shape[0]
5005 else:
5006 shape.append(1)
5007 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5008 for i in range(len(shape)):
5009 shape[i] = shape[i] + rng.integers(1, 10)
5010
5011 if error_name == ErrorIf.WrongOutputType:
5012 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5013 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5014 outputDType = rng.choice(wrong_dtypes)
5015 else:
5016 outputDType = DType.INT32
5017
5018 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005019
5020 @staticmethod
5021 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
5022
5023 # IFM: NHWC
5024 # Filter: OHWI
5025 # OFM: NHWC
5026
5027 if len(padding) == 2:
5028 # Expand padding to 4 parameters in the case of transpose_conv2d
5029 # From H,W to T,B,L,R
5030 padding = [padding[0], padding[0], padding[1], padding[1]]
5031
Kevin Cheng550ccc52021-03-03 11:21:43 -08005032 h = (
5033 ifm.shape[1]
5034 - filter.shape[1]
5035 - (filter.shape[1] - 1) * (dilations[0] - 1)
5036 + padding[0]
5037 + padding[1]
5038 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005039
Kevin Cheng550ccc52021-03-03 11:21:43 -08005040 w = (
5041 ifm.shape[2]
5042 - filter.shape[2]
5043 - (filter.shape[2] - 1) * (dilations[1] - 1)
5044 + padding[2]
5045 + padding[3]
5046 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005047
Eric Kunzee5e26762020-10-13 16:11:07 -07005048 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
5049
Kevin Cheng3a478572021-01-22 17:21:02 -08005050 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005051 out_dtype = DType.INT32
5052 elif ifm.dtype == DType.INT16:
5053 out_dtype = DType.INT48
5054 elif ifm.dtype == DType.FLOAT:
5055 out_dtype = DType.FLOAT
5056 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005057 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005058
Kevin Cheng550ccc52021-03-03 11:21:43 -08005059 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005060
5061 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07005062 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
5063
5064 # IFM: NDHWC
5065 # Filter: ODHWI
5066 # OFM: NDHWC
5067
5068 d = (
5069 ifm.shape[1]
5070 - filter.shape[1]
5071 - (filter.shape[1] - 1) * (dilations[0] - 1)
5072 + padding[0]
5073 + padding[1]
5074 ) // strides[0] + 1
5075
5076 h = (
5077 ifm.shape[2]
5078 - filter.shape[2]
5079 - (filter.shape[2] - 1) * (dilations[1] - 1)
5080 + padding[2]
5081 + padding[3]
5082 ) // strides[1] + 1
5083
5084 w = (
5085 ifm.shape[3]
5086 - filter.shape[3]
5087 - (filter.shape[3] - 1) * (dilations[2] - 1)
5088 + padding[4]
5089 + padding[5]
5090 ) // strides[2] + 1
5091
5092 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
5093
5094 if ifm.dtype == DType.INT8:
5095 out_dtype = DType.INT32
5096 elif ifm.dtype == DType.INT16:
5097 out_dtype = DType.INT48
5098 elif ifm.dtype == DType.FLOAT:
5099 out_dtype = DType.FLOAT
5100 else:
5101 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
5102
5103 return ser.addOutput(ofm_shape, out_dtype)
5104
5105 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07005106 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
5107 # IFM: NHWC
5108 # Filter: HWCM
5109 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08005110 h = (
5111 ifm.shape[1]
5112 - filter.shape[0]
5113 - (filter.shape[0] - 1) * (dilations[0] - 1)
5114 + padding[0]
5115 + padding[1]
5116 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005117
Kevin Cheng550ccc52021-03-03 11:21:43 -08005118 w = (
5119 ifm.shape[2]
5120 - filter.shape[1]
5121 - (filter.shape[1] - 1) * (dilations[1] - 1)
5122 + padding[2]
5123 + padding[3]
5124 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005125
Eric Kunzee5e26762020-10-13 16:11:07 -07005126 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
5127
Kevin Cheng3a478572021-01-22 17:21:02 -08005128 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005129 out_dtype = DType.INT32
5130 elif ifm.dtype == DType.INT16:
5131 out_dtype = DType.INT48
5132 elif ifm.dtype == DType.FLOAT:
5133 out_dtype = DType.FLOAT
5134 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005135 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005136
Kevin Cheng550ccc52021-03-03 11:21:43 -08005137 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005138
5139 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005140 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005141 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005142 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005143 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005144 h = 1
5145 w = 1
5146 else:
5147 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
5148 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
5149
5150 if error_name == ErrorIf.PoolingOutputShapeMismatch:
5151 choices = [1, 2, 3, 4, 5]
5152 h = h + rng.choice(choices)
5153 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07005154
Eric Kunzee5e26762020-10-13 16:11:07 -07005155 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005156
5157 if error_name == ErrorIf.WrongOutputType:
5158 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5159 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
5160 outputDType = rng.choice(wrong_dtypes)
5161 else:
5162 outputDType = ifm.dtype
5163
5164 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005165
5166 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005167 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005168 # input: N, IC
5169 # filter: OC, IC
5170 # output: N, OC
5171
5172 output_shape = [input.shape[0], filter.shape[0]]
5173
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005174 if error_name == ErrorIf.WrongOutputType:
5175 if input.dtype == DType.INT8:
5176 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
5177 elif input.dtype == DType.INT16:
5178 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
5179 elif input.dtype == DType.FLOAT:
5180 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
5181 out_dtype = rng.choice(a=incorrect_types)
5182 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005183 out_dtype = DType.INT32
5184 elif input.dtype == DType.INT16:
5185 out_dtype = DType.INT48
5186 elif input.dtype == DType.FLOAT:
5187 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005188 elif error_name == ErrorIf.WrongInputType:
5189 # Pick some potentially correct output dtype if input type is incorrect
5190 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005191 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005192 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005193
Kevin Cheng550ccc52021-03-03 11:21:43 -08005194 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005195
5196 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005197 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07005198 # a: N, H, C
5199 # b: N, C, W
5200 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07005201
Kevin Cheng2d60f002021-06-09 14:18:32 -07005202 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005203
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005204 if error_name == ErrorIf.WrongOutputType:
5205 if a.dtype == DType.INT8:
5206 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
5207 elif a.dtype == DType.INT16:
5208 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
5209 elif a.dtype == DType.FLOAT:
5210 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
5211 out_dtype = rng.choice(a=incorrect_types)
5212 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005213 out_dtype = DType.INT32
5214 elif a.dtype == DType.INT16:
5215 out_dtype = DType.INT48
5216 elif a.dtype == DType.FLOAT:
5217 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005218 elif error_name == ErrorIf.WrongInputType:
5219 # Pick some potentially correct output dtype if input type is incorrect
5220 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005221 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005222 raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005223
Kevin Cheng550ccc52021-03-03 11:21:43 -08005224 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005225
5226 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01005227 def concatOp(ser, axis, *a):
5228 input1 = a[0]
5229 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07005230
Matthew Haddon818ab902021-07-27 09:12:49 +01005231 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07005232
Matthew Haddon818ab902021-07-27 09:12:49 +01005233 output_shape[axis] = input1.shape[axis]
5234
5235 for tensor in remaining_inputs:
5236 output_shape[axis] += tensor.shape[axis]
5237
5238 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005239
5240 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005241 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005242
5243 output_shape = a.shape.copy()
5244
5245 for i in range(len(output_shape)):
5246 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
5247
Matthew Haddone807aae2021-10-11 18:12:58 +01005248 # Fix negative output shape if error_if test causes it
5249 if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
5250 output_shape = [i if i >= 1 else 1 for i in output_shape]
5251
5252 if error_name == ErrorIf.WrongOutputType:
5253 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5254 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5255 outputDType = rng.choice(wrong_dtypes)
5256 else:
5257 outputDType = a.dtype
5258
5259 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005260
5261 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005262 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005263 output_shape = shape.copy()
5264
5265 totalElements = 1
5266 for i in a.shape:
5267 totalElements *= i
5268
5269 # If there are any -1 elements, figure out what that dimension must be
5270 totalOutputElements = 1
5271 for i in output_shape:
5272 if i != -1:
5273 totalOutputElements *= i
5274
5275 # And fill it in
5276 for i in range(len(output_shape)):
5277 if output_shape[i] == -1:
5278 output_shape[i] = totalElements // totalOutputElements
5279
Matthew Haddone807aae2021-10-11 18:12:58 +01005280 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5281 for i in range(len(output_shape)):
5282 output_shape[i] = output_shape[i] + rng.integers(1, 10)
5283
5284 if error_name == ErrorIf.WrongOutputType:
5285 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5286 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5287 outputDType = rng.choice(wrong_dtypes)
5288 else:
5289 outputDType = a.dtype
5290
5291 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005292
5293 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005294 def sliceOp(ser, rng, a, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005295
Matthew Haddone807aae2021-10-11 18:12:58 +01005296 if error_name == ErrorIf.WrongOutputType:
5297 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5298 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5299 outputDType = rng.choice(wrong_dtypes)
5300 else:
5301 outputDType = a.dtype
5302
5303 if error_name == ErrorIf.SizeOutputShapeMismatch:
5304 output_shape = size.copy()
5305 for index in range(len(output_shape)):
5306 if output_shape[index] <= 2:
5307 output_shape[index] = output_shape[index] + rng.choice([1, 2])
5308 else:
5309 output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2])
5310 else:
5311 output_shape = size.copy()
5312
5313 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005314
5315 @staticmethod
5316 def tileOp(ser, a, multiples):
5317
5318 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005319 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005320
5321 for i in range(len(output_shape)):
5322 output_shape[i] = a.shape[i] * multiples[i]
5323
Kevin Cheng550ccc52021-03-03 11:21:43 -08005324 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005325
5326 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005327 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005328 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005329
Kevin Cheng550ccc52021-03-03 11:21:43 -08005330 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005331
Matthew Haddone807aae2021-10-11 18:12:58 +01005332 if error_name == ErrorIf.IndexOutsideBounds:
5333 for i in range(len(output_shape)):
5334 output_shape[i] = a.shape[0]
5335 else:
5336 for i in range(len(output_shape)):
5337 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005338
Matthew Haddone807aae2021-10-11 18:12:58 +01005339 if error_name == ErrorIf.WrongOutputType:
5340 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5341 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5342 outputDType = rng.choice(wrong_dtypes)
5343 else:
5344 outputDType = a.dtype
5345
5346 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005347
5348 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08005349 def gatherOp(ser, values, indices):
5350 assert len(values.shape) == 3
5351 assert len(indices.shape) == 2
5352 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005353
Kevin Cheng77d0f762020-11-24 10:26:32 -08005354 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5355
Kevin Cheng550ccc52021-03-03 11:21:43 -08005356 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005357
5358 @staticmethod
5359 def scatterOp(ser, values_in, indices, input):
5360 assert len(values_in.shape) == 3
5361 assert len(indices.shape) == 2
5362 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005363 assert values_in.shape[0] == indices.shape[0] # N
5364 assert input.shape[1] == indices.shape[1] # W
5365 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005366
5367 output_shape = values_in.shape
5368
Kevin Cheng550ccc52021-03-03 11:21:43 -08005369 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005370
5371 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005372 def tableOp(ser, input, table_dtype):
5373 # Same shape as the input, but dtype dependent on table dtype
5374 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
5375 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
5376 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005377
5378 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08005379 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005380 serializer,
5381 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005382 input,
5383 mode,
5384 stride,
5385 offset,
5386 shift,
5387 stride_fp,
5388 offset_fp,
5389 output_dims,
5390 input_dtype,
5391 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01005392 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08005393 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01005394 if error_name == ErrorIf.WrongRank:
5395 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
5396 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005397 if error_name == ErrorIf.BatchMismatch:
5398 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
5399 elif error_name == ErrorIf.ChannelMismatch:
5400 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
5401 else:
5402 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005403
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005404 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005405
5406 @staticmethod
5407 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005408 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005409
5410 @staticmethod
5411 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08005412 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005413 out_dtype = DType.INT32
5414 elif ifm.dtype == DType.INT16:
5415 out_dtype = DType.INT48
5416 elif ifm.dtype == DType.FLOAT:
5417 out_dtype = DType.FLOAT
5418 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005419 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005420
Kevin Cheng550ccc52021-03-03 11:21:43 -08005421 return ser.addOutput(output_shape, out_dtype)