blob: 3702142f9caeff80e4a8e617e652c53b02b8f7f3 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
48DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080049Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070050ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Pick a zero point for the given dtype.

        INT8/UINT8 zero points may legitimately be non-zero, so they always get
        a random value over the full range of the type.  Other dtypes must have
        a zero point of 0; when a zero-point ERROR_IF test is requested, force
        a non-zero value instead so the error actually triggers.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        if error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zp = testGen.randInt(-128, 128)
            # Guarantee non-zero so the ERROR_IF condition fires
            return zp if zp != 0 else 1
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build quantization info for unary ops (input zp, output zp)."""
        qinfo = ts.TosaSerializerQuantInfo()
        # Route error_name only to the zero point the test is meant to corrupt
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_err = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype, input_err),
            TosaQuantGen.getQinfo(testGen, dtype, output_err),
        )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build quantization info for conv-style ops (input zp, weights zp)."""
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # a single int: input, weights and accumulator share the dtype
            dtypeList = [dtype_or_dtypeList] * 3

        # Route error_name only to the zero point the test is meant to corrupt
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weight_err = error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], input_err)
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], weight_err)

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build quantization info for MATMUL (a zp, b zp)."""
        qinfo = ts.TosaSerializerQuantInfo()
        # InputZeroPointNotZero corrupts both operands' zero points
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype, err),
            TosaQuantGen.getQinfo(testGen, dtype, err),
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build quantization info for PAD (input zp only)."""
        qinfo = ts.TosaSerializerQuantInfo()
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, err))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Convert a floating-point scale factor into a fixed-point multiplier
        # and right-shift pair, with 32-bit or 16-bit multiplier precision.
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can push the mantissa up to exactly 1.0: renormalize
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            exponent = exponent + 1

        shift = (-exponent) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
175
176
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one random shape per operand; all operands share the shape.

        For RankMismatch ERROR_IF tests, operands other than index 1 get a
        regenerated shape with a deliberately different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # rank 1 can only grow - there is no rank 0 to shrink to here
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical NHWC shapes for all operands, batch-size limited."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for gather/scatter-style ops.

        The second shape shares N and C with the first but gets a random W.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return operand shapes where one randomly chosen operand has a
        randomly chosen dimension set to 1 to exercise broadcasting."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    # rank 1 can only grow - there is no rank 0 to shrink to here
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter, bias] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        # NOTE(review): return value unused - presumably eiRestrictDimension
        # mutates input_shape in place; verify against its implementation
        shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a, b] shapes for MATMUL with matching batch and inner dims."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        # NOTE(review): return value unused - presumably eiRestrictDimension
        # mutates a_shape in place; verify against its implementation
        shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return a random number of identical shapes (2 extra at most... 0-3)
        so CONCAT is exercised with a varying operand count."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first concat shape along 'axis' so the const inputs are
        fractions of it rather than full-size copies.

        Returns shapeList unchanged when there are only two shapes or the axis
        is too small to split across all inputs.
        """
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                # Last iteration: the leftover length becomes the final input
                new_shapeList.append(shape.copy())

        return new_shapeList
490
491
Eric Kunzee5e26762020-10-13 16:11:07 -0700492class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800493 """Argument generators create exhaustive or random lists of attributes for operators that take
494 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
495 tuples where the descriptive_name is appended to the test name and the arglist is expanded
496 as arguments to the operator build function."""
497
    def __init__(self):
        # No instance state; all argument generators are static methods.
        pass
500
501 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100502 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800503 """A trivial argument generator for operators that don't take any
504 non-tensor arguments"""
505 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
507 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100508 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800509 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700511 shape = shapeList[0]
512
Matthew Haddond6ce7252021-09-29 15:35:44 +0100513 if error_name == ErrorIf.AxisSmallerZero:
514 small_axis = testGen.rng.integers(-5, 0)
515 axes.append(("axis{}".format(small_axis), [small_axis]))
516 elif error_name == ErrorIf.AxisLargerRank:
517 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
518 axes.append(("axis{}".format(large_axis), [large_axis]))
519 else:
520 for a in range(0, len(shape)):
521 axes.append(("axis{}".format(a), [a]))
522
Eric Kunzee5e26762020-10-13 16:11:07 -0700523 return axes
524
525 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100526 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700527 arg_list = []
528
529 ifm_shape = shapeList[0]
530 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100531 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
532 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700533
Les Bell7aa69f42021-09-20 10:44:07 +0100534 # Check the rank
535 rank = 5 if opName.startswith("conv3d") else 4
536 assert len(ifm_shape) == rank
537 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700538
Les Bell7aa69f42021-09-20 10:44:07 +0100539 # kernel rank omits batch and channels
540 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
Les Bell7aa69f42021-09-20 10:44:07 +0100542 # Generate comprehensive argument lists
543 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
544 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
545 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
546 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
547 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
548 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700549
Les Bell7aa69f42021-09-20 10:44:07 +0100550 # add some oversize argument values
551 if max(ifm_shape) < 64:
552 bigPadding = 9
553 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
554 bigStride = 8
555 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
556 bigDilation = 7
557 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100558
559 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100560 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
561 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
562 if sparsity < 13:
563 sparsity = 1
564 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
565 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100566 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100567 for s in sorted(list(strides)):
568 for p in sorted(list(paddings)):
569 for d in sorted(list(dilations)):
570 if (n % sparsity == 0
571 # padding must not exceed the kernel size ?
572 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
573 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
574 # the padded shape must exceed the kernel size
575 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
576 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
577 # the padded shape must exceed the dilation
578 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
579 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
580 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100581 arg_list.append(
582 (
583 "st{}_pad{}_dilat{}".format(
584 "".join([str(x) for x in s]),
585 "".join([str(x) for x in p]),
586 "".join([str(x) for x in d]),
587 ),
588 [s, p, d],
589 )
590 )
591 n += 1
592
Kevin Cheng1533b852021-09-01 12:51:58 -0700593 return arg_list
594
595 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100596 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 arg_list = []
598
599 ifm_shape = shapeList[0]
600 filter_shape = shapeList[1]
601
602 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800603 assert len(ifm_shape) == 4
604 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Les Bell7aa69f42021-09-20 10:44:07 +0100606 # Generate comprehensive argument lists
607 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
608 paddings = {x for x in itertools.product(*([p_vals] * 2))}
609 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
610 strides = {x for x in itertools.product(*([s_vals] * 2))}
611 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
612 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
Les Bell7aa69f42021-09-20 10:44:07 +0100614 # add some oversize argument values
615 if max(ifm_shape) < 64:
616 bigPadding = 9
617 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
618 bigStride = 8
619 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
620 bigDilation = 7
621 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Les Bell7aa69f42021-09-20 10:44:07 +0100623 # There are too many parameter combinations, so generate them sparsely
624 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
625 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
626 if sparsity < 13:
627 sparsity = 1
628 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
629 sparsity += 1
630 n = 0
631 for s in sorted(list(strides)):
632 for p in sorted(list(paddings)):
633 for d in sorted(list(dilations)):
634 if n % sparsity == 0:
635 # Determine the output shape
636 oh = (
637 ifm_shape[1]
638 - filter_shape[1]
639 - (filter_shape[1] - 1) * (d[0] - 1)
640 + 2 * p[0]
641 ) // s[0] + 1
642 ow = (
643 ifm_shape[2]
644 - filter_shape[2]
645 - (filter_shape[2] - 1) * (d[1] - 1)
646 + 2 * p[1]
647 ) // s[1] + 1
648 os = [ifm_shape[0], oh, ow, filter_shape[0]]
649 arg_list.append(
650 (
651 "st{}_pad{}_dilat{}_os{}".format(
652 "".join([str(x) for x in s]),
653 "".join([str(x) for x in p]),
654 "".join([str(x) for x in d]),
655 "x".join([str(x) for x in os]),
656 ),
657 [s, p, d, os],
658 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 )
Les Bell7aa69f42021-09-20 10:44:07 +0100660 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
662 return arg_list
663
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Exhaustively build (name, [padding, pad_const_int, pad_const_fp])
        argument combinations for the PAD operator."""
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Deliberately invalid negative paddings for the ERROR_IF test
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Pick a random pad constant of the matching kind; the other constant
        # is fixed at 0
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype == DType.FLOAT:
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype for PAD: produce no argument combinations
            return []

        for paddings in shape_pad_values:
            # Encode each dimension's before/after padding into the test name
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))

        return arg_list
696
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Build (name, [stride, padding, kernel]) argument combinations for
        pooling operators, sparsely sampled to bound the test count."""
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        # sorted() keeps the sparse sampling deterministic across runs
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        # Let the ERROR_IF helper corrupt stride/pad/kernel; it
                        # may return None values when no corruption applies
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
761
762 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100763 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700764 arg_list = []
765
766 # Enumerate the output types here
767 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800768 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700769 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800770 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700771 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800772 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700773 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800774 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700775 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800776 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700777 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800778 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700779
780 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800781 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700782
783 return arg_list
784
785 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100786 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700787 arg_list = []
788
789 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100790 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100791 if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
792 continue
793 if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100794 # The only output dtype for UINT8 is INT8, skip all other combinations
795 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100796 if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100797 # The only input dtype for UINT8 is INT8, skip all other combinations
798 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100799 if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
800 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100801
Kevin Cheng550ccc52021-03-03 11:21:43 -0800802 for scale32 in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100803 if error_name == ErrorIf.ScaleTrue and scale32 == False:
804 continue
805 elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
806 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800807 for double_round in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100808 if error_name == ErrorIf.ScaleNotTrue and double_round == False:
809 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800810 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700811
Matthew Haddonc2025212021-10-08 21:21:05 +0100812 if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
Eric Kunzee5e26762020-10-13 16:11:07 -0700813 # Illegal condition. Must be scale32=False
814 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100815 if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100816 # Illegal condition. ERROR_IF(!scale32 && double_round)
817 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700818
Kevin Cheng550ccc52021-03-03 11:21:43 -0800819 arg_list.append(
820 (
821 "out{}_sc{}_dr{}_pc{}".format(
822 DTypeNames[dtype],
823 int(scale32),
824 int(double_round),
825 int(per_channel),
826 ),
827 [dtype, scale32, double_round, per_channel],
828 )
829 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700830
831 return arg_list
832
Kevin Chengaee1fac2020-11-11 13:54:06 -0800833 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100834 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800835 arg_list = []
836
837 if dtype is DType.INT32:
838 for p in range(testGen.args.num_rand_permutations):
839
840 shift = testGen.randInt(0, 32)
841
Kevin Cheng550ccc52021-03-03 11:21:43 -0800842 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800843 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100844 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800845
846 return arg_list
847
848 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100849 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800850 arg_list = []
851
Kevin Cheng550ccc52021-03-03 11:21:43 -0800852 arg_list.append(("roundTrue", [True]))
853 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800854
855 return arg_list
856
Eric Kunzee5e26762020-10-13 16:11:07 -0700857 # Helper function for reshape. Gets some factors of a larger number.
858 @staticmethod
859 def getFactors(val, start=1):
860 factors = []
861
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100862 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700863 if (val % i) == 0:
864 factors.append(i)
865
866 return factors
867
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random target shapes for RESHAPE.

        Builds up to num_rand_permutations shapes whose element count matches
        the input shape, occasionally substituting a -1 wildcard dimension.
        Returns a list of ("perm<N>_rank<R>", [newShape]) tuples.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total element count; every generated shape must factor into this.
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough distinct factors to fill the requested rank
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension absorbs whatever element count remains
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
923
Eric Kunzee5e26762020-10-13 16:11:07 -0700924 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100925 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700926 arg_list = []
927
928 ifm_shape = shapeList[0]
929
Matthew Haddone807aae2021-10-11 18:12:58 +0100930
931 if error_name == ErrorIf.IndexOutsideBounds:
932 incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
933 incorrect_small_index = range(-len(ifm_shape), 0)
934 permutations = [p for p in itertools.permutations(incorrect_large_index)]
935 permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
936 elif error_name == ErrorIf.IndexUsedTwice:
937 # Create list with a duplicated index
938 perm_range = list(range(len(ifm_shape)))
939 index_choice = testGen.rng.choice(range(len(perm_range)))
940 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
941 permutations = [p for p in itertools.permutations(perm_range)]
942
943
944 else:
945 # Get all permutations
946 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700947
Jeremy Johnsona6185572021-06-21 15:55:35 +0100948 # Limit to possible permutations from shape dimension or argument setting
949 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700950
Jeremy Johnsona6185572021-06-21 15:55:35 +0100951 # Get random permutation generator that uses all permutations
952 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700953
Jeremy Johnsona6185572021-06-21 15:55:35 +0100954 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700955 arg_list = [
956 ("perm{}".format(p), [random_permutations[p].tolist()])
957 for p in range(limit)
958 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700959 return arg_list
960
961 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100962 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700963 arg_list = []
964
965 ifm_shape = shapeList[0]
966 rank = len(ifm_shape)
967
968 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +0100969 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700970 size = []
971
Kevin Cheng550ccc52021-03-03 11:21:43 -0800972 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700973
974 for i in range(rank):
975 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +0100976 start.append(testGen.randInt(0, ifm_shape[i]))
977 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700978
979 # Invalid slice size?
980 if size[i] == 0:
981 valid = False
982 else:
Matthew Haddone807aae2021-10-11 18:12:58 +0100983 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -0700984 size.append(1)
985
986 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +0100987 # If ERROR_IF test required then incorrect start, size will be returned
988 start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
989 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700990 return arg_list
991
992 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100993 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700994 arg_list = []
995
996 ifm_shape = shapeList[0]
997 rank = len(ifm_shape)
998
999 for p in range(testGen.args.num_rand_permutations):
1000
1001 # Pick a few random, but small multiple values
1002 # because otherwise this has a tendency to generate
1003 # enormous tensors
1004 multiples = []
1005 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001006 if ifm_shape[i] > 1000:
1007 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
1008 multiples.append(1)
1009 elif max(ifm_shape) > 1000:
1010 multiples.append(2)
1011 else:
1012 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001013 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001014
1015 return arg_list
1016
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate RESIZE argument sets.

        For each legal {mode, output dtype} pair, randomly chooses output
        dimensions, then derives stride/offset values (fixed-point for the
        integer path, floating-point for the FLOAT path).  When error_name is
        set, the values are perturbed by TosaErrorIfArgGen.eiResizeErrorIf to
        trigger the requested ERROR_IF.  Returns (name, [mode, stride, offset,
        shift, stride_fp, offset_fp, output_dims, dtype, outputDType]) tuples.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Keep the downscale ratio below 16x in each axis
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Centre-aligned resize: offsets map the output centre
                    # onto the input centre
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT path uses the fp stride/offset directly; the
                        # fixed-point fields are zeroed
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        shift = testGen.randInt(1,12)
                        # Now search for a shift value (1 to 11) that will produce
                        # a valid and predictable resize operation
                        count = 0
                        while (count < 12):
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                            if (
                                stride_y >= (16 << shift)
                                or stride_x >= (16 << shift)
                                or offset_y >= (16 << shift)
                                or offset_x >= (16 << shift)
                                or offset_y <= (-16 << shift)
                                or offset_x <= (-16 << shift)
                            ):
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue

                            def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift):
                                # Perform the pseudo loop to look for out of bounds
                                for pos in range(0,length_out):
                                    a = pos * stride + offset
                                    ia = a >> shift
                                    ia0 = max(ia, 0)
                                    ia1 = min(ia+1, length_in-1)
                                    if ia0 > ia1:
                                        # Found a problem value
                                        break
                                return ia0, ia1

                            iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift)
                            ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift)
                            if ix0 > ix1 or iy0 > iy1:
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue
                            break

                        if count >= 12:
                            # Couldn't find a good set of values for this test, skip it
                            continue

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        # Integer path zeroes the floating-point fields
                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list
1214
Kevin Chengfe392ce2021-10-18 21:51:55 +00001215 @staticmethod
1216 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1217 arg_list = []
1218
1219 if dtype == DType.INT8:
1220 table = np.int32(
1221 testGen.rng.integers(low=-128, high=128, size=[256])
1222 ).tolist()
1223 else: # INT16
1224 table = np.int32(
1225 testGen.rng.integers(low=-32768, high=32768, size=[513])
1226 ).tolist()
1227
1228 arg_list.append(
1229 (
1230 "",
1231 [table],
1232 )
1233 )
1234 return arg_list
1235
Matthew Haddon1c00b712021-10-01 15:51:03 +01001236 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001237 # CondIf generates the condition values here.
1238 # Convert to tensors in the build function, along with the
1239 # then and else blocks
1240 arg_list = []
1241
1242 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001243 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001244
1245 return arg_list
1246
Matthew Haddon1c00b712021-10-01 15:51:03 +01001247 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001248 # While loop: 0 iterations, 1, more than 1
1249 arg_list = []
1250
1251 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001252 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001253
1254 return arg_list
1255
Matthew Haddone86fd342021-09-07 16:12:21 +01001256class TosaErrorIfArgGen:
1257
1258 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001259 def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
Matthew Haddone86fd342021-09-07 16:12:21 +01001260
1261 if outputDType == DType.FLOAT:
1262 if error_name == ErrorIf.StrideSmallerEqualZero:
1263 stride_fp = testGen.rng.random(size=[2]) - 2
1264 elif error_name == ErrorIf.ShiftNotZero:
1265 shift = testGen.rng.integers(1, 5)
1266 elif error_name == ErrorIf.StrideLargerDimension:
1267 shape = shapeList[0]
1268 transform_height = testGen.rng.choice([False, True])
1269 if transform_height:
1270 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
1271 else:
1272 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
1273 else:
1274 if error_name == ErrorIf.StrideSmallerEqualZero:
1275 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
1276 elif error_name == ErrorIf.ShiftSmallerOne:
1277 shift = testGen.rng.integers(-3, 1)
1278 if shift <= 0:
1279 stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1280 offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1281 else:
1282 stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1283 offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1284 elif error_name == ErrorIf.ShiftLargerEleven:
1285 shift = np.int16(testGen.rng.integers(12, 15))
1286 elif error_name == ErrorIf.StrideLargerDimension:
1287 shape = shapeList[0]
1288 transform_height = testGen.rng.choice([False, True])
1289 if transform_height:
1290 stride[0] = shape[1] + testGen.rng.integers(1, 10)
1291 else:
1292 stride[1] = shape[2] + testGen.rng.integers(1, 10)
1293 elif error_name == ErrorIf.StrideLargerEqualMax:
1294 stride = [(16 << shift) + 1, (16 << shift) + 1]
1295 elif error_name == ErrorIf.OffsetLargerEqualMax:
1296 offset = [(16 << shift) + 1, (16 << shift) + 1]
1297 elif error_name == ErrorIf.OffsetSmallerEqualMin:
1298 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
1299
Matthew Haddon1c00b712021-10-01 15:51:03 +01001300
Matthew Haddon848efb42021-09-09 12:30:53 +01001301 if error_name == ErrorIf.WrongOutputType:
1302 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
1303 incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1304 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
1305 incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
1306 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
1307 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
1308 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
1309 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
1310 elif dtype == DType.FLOAT:
1311 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
1312 outputDType = testGen.rng.choice(a=incorrect_types)
Matthew Haddone86fd342021-09-07 16:12:21 +01001313
Matthew Haddon848efb42021-09-09 12:30:53 +01001314 return shift, stride, stride_fp, offset, offset_fp, outputDType
1315
Matthew Haddone807aae2021-10-11 18:12:58 +01001316
Matthew Haddon848efb42021-09-09 12:30:53 +01001317 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001318 def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
1319 if (error_name == ErrorIf.StrideSmallerOne
1320 # padding must not exceed the kernel size
1321 and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
1322 wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
1323 return wrongStride, pad, kernel
1324 elif error_name == ErrorIf.PadSmallerZero:
1325 wrongPad = (testGen.rng.choice([-1, -2, -3]),
1326 testGen.rng.choice([-1, -2, -3]),
1327 testGen.rng.choice([-1, -2, -3]),
1328 testGen.rng.choice([-1, -2, -3]))
1329 return stride, wrongPad, kernel
1330 elif error_name == ErrorIf.KernelSmallerOne:
1331 wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
1332 return stride, pad, wrongKernel
1333 elif error_name == ErrorIf.PadLargerEqualKernel:
1334 wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
1335 testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
1336 testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
1337 testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
1338 return stride, wrongPad, kernel
1339 else:
1340 return None, None, None
1341
Matthew Haddone807aae2021-10-11 18:12:58 +01001342
Matthew Haddonc2025212021-10-08 21:21:05 +01001343 @staticmethod
1344 def eiRescaleWrongOutputType(input_dtype, output_dtype):
1345 if input_dtype == DType.INT8:
1346 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
1347 return True
1348 if input_dtype in [DType.INT16, DType.INT32]:
1349 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1350 return True
1351 elif input_dtype == DType.INT48:
1352 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1353 return True
1354 elif input_dtype == DType.UINT8:
1355 if output_dtype != DType.INT8:
1356 return True
1357 return False
1358
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001359
1360 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001361 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
1362 # Mess up input/output tensors for ERROR_IF checks
1363 if error_name == "WrongInputList":
1364 add_input = testGen.rng.choice([True, False])
1365 if add_input:
1366 input_list.append('eiDummyInput')
1367 else:
1368 input_list = input_list[:-1]
1369 if error_name == "WrongOutputList":
1370 add_output = testGen.rng.choice([True, False])
1371 if add_output:
1372 output_list.append('eiDummyOutput')
1373 else:
1374 output_list = []
1375 return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001376
Matthew Haddone807aae2021-10-11 18:12:58 +01001377
Matthew Haddonc2025212021-10-08 21:21:05 +01001378 @staticmethod
1379 def eiRestrictDimension(shape, error_name):
1380 # Restrict dimension size if rank is large for WrongRank Error_If
1381 # This will keep the test sizes reasonably small
1382 if error_name == ErrorIf.WrongRank:
1383 if len(shape) > 4:
1384 shape[4] = 1
1385
1386 return shape
1387
Matthew Haddone807aae2021-10-11 18:12:58 +01001388
1389 def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
1390 if error_name == ErrorIf.StartSmallerZero:
1391 newStart = []
1392 for i in range(len(input_shape)):
1393 newStart.append(testGen.rng.choice([-3, -2, -1]))
1394 return newStart, size
1395 elif error_name == ErrorIf.SizeSmallerEqualZero:
1396 newSize = []
1397 for i in range(len(input_shape)):
1398 newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
1399 return start, newSize
1400 elif error_name == ErrorIf.StartSizeOutsideBounds:
1401 newStart, newSize = [], []
1402 for i in range(len(input_shape)):
1403 newStart.append(input_shape[i]-1)
1404 newSize.append(testGen.rng.choice([2, 3, 4]))
1405 return newStart, newSize
1406 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
1407 remove = testGen.rng.choice([True, False])
1408 if remove:
1409 newStart = start[1:]
1410 newSize = size[1:]
1411 else:
1412 newStart = start
1413 newStart.append(1)
1414 newSize = size
1415 newSize.append(1)
1416 return newStart, newSize
1417 else:
1418 return start, size
1419
Matthew Haddone86fd342021-09-07 16:12:21 +01001420class TosaErrorValidator:
1421
Matthew Haddon848efb42021-09-09 12:30:53 +01001422 @staticmethod
1423 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1424 # Check ERROR_IF statements
1425
1426 for val_fcn in validator_fcns:
1427 val_result = val_fcn(True, **kwargs)
1428
1429 validator_name = val_result['error_name']
1430 error_result = val_result['error_result']
1431 error_reason = val_result['error_reason']
1432
1433 if error_result:
1434 if error_name == validator_name:
1435 serializer.setExpectedReturnCode(2, error_reason)
1436 else:
1437 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1438 return None # Return None to delete test if wrong ERROR_IF is hit
1439 else:
1440 if error_name == validator_name:
1441 print(f"No ERROR_IF hit for {error_name}")
1442 return None
1443
1444 @staticmethod
1445 def evWrongInputType(check=False, **kwargs):
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001446 all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
Matthew Haddon848efb42021-09-09 12:30:53 +01001447
1448 # Find the unsupported input data types
1449 assert 'op' in kwargs
1450 op = kwargs['op']
1451 input_dtypes = op['types']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001452
1453 allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
1454 wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
Matthew Haddon848efb42021-09-09 12:30:53 +01001455
1456 error_name = ErrorIf.WrongInputType
1457 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1458 error_result = False
1459 error_reason = "Input data type not supported for this operator"
1460
1461 if check:
1462 input_dtype = kwargs['input_dtype']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001463 if op['op'] == Op.FULLY_CONNECTED:
1464 if input_dtype not in allowed_input_dtypes:
1465 error_result = True
1466 elif input_dtype not in input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001467 error_result = True
1468
1469 info_dict = {
1470 "error_name": error_name,
1471 "error_result": error_result,
1472 "error_reason": error_reason,
1473 "param_reqs": param_reqs
1474 }
1475 return info_dict
1476
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """ERROR_IF check: output dtype must be valid for the op's input dtype
        (and, for RESIZE, its mode). Ops without a special case require the
        output dtype to equal the input dtype."""
        error_name = ErrorIf.WrongOutputType
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output data type not supported for this configuration of operator"

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            op = kwargs['op']

            if op['op'] == Op.RESIZE:
                # Legal pairs depend on both the resize mode and the input dtype.
                mode = kwargs['mode']
                if (
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            elif op['op'] == Op.RESCALE:
                # NOTE(review): the INT16/INT32 test below is an `if`, not `elif`;
                # behavior is unchanged since it cannot overlap the INT8 case,
                # but it breaks the chain — confirm this is intentional.
                if input_dtype == DType.INT8:
                    if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                if input_dtype in [DType.INT16, DType.INT32]:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.INT48:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.UINT8:
                    if output_dtype != DType.INT8:
                        error_result = True
            elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator widening: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            elif op['op'] == Op.ARGMAX:
                if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
                    error_result = True
            else:
                # Default rule: output dtype must match input dtype.
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1533
1534 @staticmethod
1535 def evWrongRank(check=False, **kwargs):
1536 all_ranks = (1, 2, 3, 4, 5)
1537
1538 # Make a list of incorrect ranks
1539 assert 'op' in kwargs
1540 op = kwargs['op']
1541 rmin, rmax = op['rank']
1542 rank_range = range(rmin, rmax + 1)
1543 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001544 # Remove small incorrect ranks to avoid index errors
1545 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
Matthew Haddon848efb42021-09-09 12:30:53 +01001546 # Set minimum incorrect rank to 3 to avoid index error
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001547 if op['op'] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001548 incorrect_ranks = [3, 5]
1549
1550 error_name = ErrorIf.WrongRank
1551 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1552 error_result = False
1553 error_reason = "Rank not supported for this operator"
1554
1555 if check:
1556 input_shape = kwargs['input_shape']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001557
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001558 if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
Matthew Haddon848efb42021-09-09 12:30:53 +01001559 error_result = True
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001560 elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
1561 error_result = True
1562 elif op['op'] == Op.MATMUL and len(input_shape) != 3:
1563 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001564 else:
1565 if len(input_shape) not in rank_range:
1566 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001567
1568 info_dict = {
1569 "error_name": error_name,
1570 "error_result": error_result,
1571 "error_reason": error_reason,
1572 "param_reqs": param_reqs
1573 }
1574 return info_dict
1575
1576 @staticmethod
1577 def evWrongInputList(check=False, **kwargs):
1578 error_name = ErrorIf.WrongInputList
1579 param_reqs = {"rank": None, "dtype": None, "shape": None}
1580 error_result = False
1581 error_reason = "Op input list does not match expected input"
1582
1583 if check:
1584 op = kwargs['op']
1585 input_list = kwargs['input_list']
1586 num_operands = kwargs['num_operands']
Kevin Chengfe392ce2021-10-18 21:51:55 +00001587 if len(input_list) != num_operands:
1588 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001589
1590 info_dict = {
1591 "error_name": error_name,
1592 "error_result": error_result,
1593 "error_reason": error_reason,
1594 "param_reqs": param_reqs
1595 }
1596 return info_dict
1597
1598 @staticmethod
1599 def evWrongOutputList(check=False, **kwargs):
1600 error_name = ErrorIf.WrongOutputList
1601 param_reqs = {"rank": None, "dtype": None, "shape": None}
1602 error_result = False
1603 error_reason = "Op output list does not match expected output"
1604
1605 if check:
1606 output_list = kwargs['output_list']
1607 # Note this will be incorrect if an operator returns more than one output
1608 if len(output_list) != 1:
1609 error_result = True
1610
1611 info_dict = {
1612 "error_name": error_name,
1613 "error_result": error_result,
1614 "error_reason": error_reason,
1615 "param_reqs": param_reqs
1616 }
1617 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001618
1619 @staticmethod
1620 def evMaxDimExceeded(check=False, **kwargs):
1621 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001622 param_reqs = {
1623 "rank": [4,4],
1624 "dtype": [DType.INT8],
1625 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1626 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001627 error_result = False
1628 error_reason = "At least one maximum dimension is larger than 16384"
1629
1630 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001631 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001632 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1633 if ((input_shape[1] > 16384) or
1634 (input_shape[2] > 16384) or
1635 (output_shape[0] > 16384) or
1636 (output_shape[1] > 16384)):
1637 error_result = True
1638
1639 info_dict = {
1640 "error_name": error_name,
1641 "error_result": error_result,
1642 "error_reason": error_reason,
1643 "param_reqs": param_reqs
1644 }
1645 return info_dict
1646
1647 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001648 def evBatchMismatch(check=False, **kwargs):
1649 error_name = ErrorIf.BatchMismatch
1650 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1651 error_result = False
1652 error_reason = "Input batch size not equal to output batch size"
1653
1654 assert 'op' in kwargs
1655 op = kwargs['op']
1656 rmin, rmax = op['rank']
1657 rank_range = range(rmin, rmax + 1)
1658
1659 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001660 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001661 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1662
1663 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1664 error_result = True
1665
1666 info_dict = {
1667 "error_name": error_name,
1668 "error_result": error_result,
1669 "error_reason": error_reason,
1670 "param_reqs": param_reqs
1671 }
1672 return info_dict
1673
1674 @staticmethod
1675 def evChannelMismatch(check=False, **kwargs):
1676 error_name = ErrorIf.ChannelMismatch
1677 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1678 error_result = False
1679 error_reason = "Input channel size not equal to output channel size"
1680
1681 assert 'op' in kwargs
1682 op = kwargs['op']
1683 rmin, rmax = op['rank']
1684 rank_range = range(rmin, rmax + 1)
1685
1686 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001687 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001688 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1689 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1690 error_result = True
1691
1692 info_dict = {
1693 "error_name": error_name,
1694 "error_result": error_result,
1695 "error_reason": error_reason,
1696 "param_reqs": param_reqs
1697 }
1698 return info_dict
1699
1700 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001701 def evStrideSmallerEqualZero(check=False, **kwargs):
1702 error_name = ErrorIf.StrideSmallerEqualZero
1703 param_reqs = {"rank": None, "dtype": None, "shape": None}
1704 error_result = False
1705 error_reason = "Stride value smaller than or equal zero"
1706
1707 if check:
1708 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001709 output_dtype = kwargs['output_dtype']
1710 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1711 stride = kwargs['stride'] # Work around wrong input/output type tests
1712 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001713 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001714 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1715 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001716 else:
1717 stride = kwargs['stride']
1718
1719 if min(stride) <= 0:
1720 error_result = True
1721
1722 info_dict = {
1723 "error_name": error_name,
1724 "error_result": error_result,
1725 "error_reason": error_reason,
1726 "param_reqs": param_reqs
1727 }
1728 return info_dict
1729
1730 @staticmethod
1731 def evStrideLargerEqualMax(check=False, **kwargs):
1732 error_name = ErrorIf.StrideLargerEqualMax
1733 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1734 error_result = False
1735 error_reason = "Stride value larger than or equal to maximum value"
1736
1737 if check:
1738 shift = kwargs['shift']
1739 input_dtype = kwargs['input_dtype']
1740 stride = kwargs['stride']
1741 if input_dtype in [DType.INT8, DType.INT16]:
1742 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1743 error_result = True
1744 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1745 error_result = True
1746
1747 info_dict = {
1748 "error_name": error_name,
1749 "error_result": error_result,
1750 "error_reason": error_reason,
1751 "param_reqs": param_reqs
1752 }
1753 return info_dict
1754
1755
1756 @staticmethod
1757 def evStrideLargerDimension(check=False, **kwargs):
1758 error_name = ErrorIf.StrideLargerDimension
1759 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1760 error_result = False
1761 error_reason = "Stride value larger than or equal to H/W dimension"
1762
1763 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001764 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001765 input_dtype = kwargs['input_dtype']
1766 stride = kwargs['stride_fp']
1767
1768 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1769 error_result = True
1770
1771 info_dict = {
1772 "error_name": error_name,
1773 "error_result": error_result,
1774 "error_reason": error_reason,
1775 "param_reqs": param_reqs
1776 }
1777 return info_dict
1778
1779
1780 @staticmethod
1781 def evOffsetSmallerEqualMin(check=False, **kwargs):
1782 error_name = ErrorIf.OffsetSmallerEqualMin
1783 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1784 error_result = False
1785 error_reason = "Offset value smaller than or equal to minimum value"
1786
1787 if check:
1788 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001789 output_dtype = kwargs['output_dtype']
1790 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001791 offset = kwargs['offset_fp']
1792 else:
1793 offset = kwargs['offset']
1794
1795 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1796 error_result = True
1797 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1798 error_result = True
1799
1800 info_dict = {
1801 "error_name": error_name,
1802 "error_result": error_result,
1803 "error_reason": error_reason,
1804 "param_reqs": param_reqs
1805 }
1806 return info_dict
1807
1808 @staticmethod
1809 def evOffsetLargerEqualMax(check=False, **kwargs):
1810 error_name = ErrorIf.OffsetLargerEqualMax
1811 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1812 error_result = False
1813 error_reason = "Offset value larger than or equal to maximum value"
1814
1815 if check:
1816 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001817 output_dtype = kwargs['output_dtype']
1818 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001819 offset = kwargs['offset_fp']
1820 else:
1821 offset = kwargs['offset']
1822
1823 if shift >= 0:
1824 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1825 error_result = True
1826
1827 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1828 error_result = True
1829 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1830 error_result = True
1831
1832 info_dict = {
1833 "error_name": error_name,
1834 "error_result": error_result,
1835 "error_reason": error_reason,
1836 "param_reqs": param_reqs
1837 }
1838 return info_dict
1839
1840 @staticmethod
1841 def evShiftNotZero(check=False, **kwargs):
1842 error_name = ErrorIf.ShiftNotZero
1843 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1844 error_result = False
1845 error_reason = "Shift value must be zero for float input"
1846
1847 if check:
1848 shift = kwargs['shift']
1849 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001850 output_dtype = kwargs['output_dtype']
1851 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001852 error_result = True
1853
1854 info_dict = {
1855 "error_name": error_name,
1856 "error_result": error_result,
1857 "error_reason": error_reason,
1858 "param_reqs": param_reqs
1859 }
1860 return info_dict
1861
1862
1863 @staticmethod
1864 def evShiftSmallerOne(check=False, **kwargs):
1865 error_name = ErrorIf.ShiftSmallerOne
1866 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1867 error_result = False
1868 error_reason = "Shift value smaller than one"
1869
1870 if check:
1871 shift = kwargs['shift']
1872 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001873 output_dtype = kwargs['output_dtype']
1874 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001875 error_result = True
1876
1877 info_dict = {
1878 "error_name": error_name,
1879 "error_result": error_result,
1880 "error_reason": error_reason,
1881 "param_reqs": param_reqs
1882 }
1883 return info_dict
1884
1885 @staticmethod
1886 def evShiftLargerEleven(check=False, **kwargs):
1887 error_name = ErrorIf.ShiftLargerEleven
1888 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1889 error_result = False
1890 error_reason = "Shift value larger than eleven"
1891
1892 if check:
1893 shift = kwargs['shift']
1894 if shift > 11:
1895 error_result = True
1896
1897 info_dict = {
1898 "error_name": error_name,
1899 "error_result": error_result,
1900 "error_reason": error_reason,
1901 "param_reqs": param_reqs
1902 }
1903 return info_dict
1904
1905
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001906 @staticmethod
1907 def evRankMismatch(check=False, **kwargs):
1908 error_name = ErrorIf.RankMismatch
1909 param_reqs = {"rank": None, "dtype": None, "shape": None}
1910 error_result = False
1911 error_reason = "Input Rank does not match output rank"
1912
1913 if check:
1914 input1_shape = kwargs['input1'].shape
1915 input2_shape = kwargs['input2'].shape
1916 output_shape = kwargs['result_tensor'].shape
1917 if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
1918 error_result = True
1919
1920 info_dict = {
1921 "error_name": error_name,
1922 "error_result": error_result,
1923 "error_reason": error_reason,
1924 "param_reqs": param_reqs
1925 }
1926 return info_dict
1927
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """ERROR_IF check: input zero point must be 0 unless the input dtype is
        INT8 (or UINT8). MATMUL is special-cased because it has two quantized
        inputs, each with its own zero point."""
        op = kwargs['op']
        inputDtypes = op['types'].copy()
        # If inputDtypes is a list then only the first two elements are INT8 inputs
        # NOTE(review): `op['types'].copy()` always yields a list, so this
        # isinstance test is always true and the first two entries are always
        # dropped — confirm the intent was not `isinstance(inputDtypes[0], list)`.
        if isinstance(inputDtypes, list):
            inputDtypes = inputDtypes[2:]

        # INT8/UINT8 inputs legitimately carry a zero point, so exclude them
        # from the dtypes this error case is generated for.
        if DType.INT8 in inputDtypes:
            inputDtypes.remove(DType.INT8)
        if DType.UINT8 in inputDtypes:
            inputDtypes.remove(DType.UINT8)

        error_name = ErrorIf.InputZeroPointNotZero
        param_reqs = {
            "rank": None,
            "dtype": inputDtypes,
            "shape": None
        }
        error_result = False
        error_reason = "Input DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs['input_dtype']
            # qinfo arrives either as a plain (input_zp, output_zp) tuple or as
            # a serializer object exposing .ints.
            if isinstance(kwargs['qinfo'], tuple):
                qinfo = kwargs['qinfo']
                input_zero_point = qinfo[0]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs['qinfo'].ints
                input_zero_point = qinfo[0][1]

            if op['op'] == Op.MATMUL:
                # Two quantized inputs: either non-INT8 input with a non-zero
                # zero point triggers the error.
                input1_dtype = kwargs['input_dtype']
                input2_dtype = kwargs['input2_dtype']
                qinfo = kwargs['qinfo'].ints
                input1_zero_point = qinfo[0][1]
                input2_zero_point = qinfo[1][1]
                if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
                    error_result = True
            else:
                if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1979
1980
1981 @staticmethod
1982 def evWeightZeroPointNotZero(check=False, **kwargs):
1983 op = kwargs['op']
1984
1985 # exclude inputs with INT8 weights
1986 inputDtypes = [t for t in op['types']
1987 if not isinstance(t, list) or t[1] != DType.INT8]
1988
1989 error_name = ErrorIf.WeightZeroPointNotZero
1990 param_reqs = {
1991 "rank": None,
1992 "dtype": inputDtypes,
1993 "shape": None
1994 }
1995 error_result = False
1996 error_reason = "Weight DType not INT8 and zero point not 0"
1997
1998 if check:
1999 weight_dtype = kwargs['weight_dtype']
2000 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
2001 qinfo = kwargs['qinfo'].ints
2002 weight_zero_point = qinfo[1][1]
2003 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002004 error_result = True
2005
2006 info_dict = {
2007 "error_name": error_name,
2008 "error_result": error_result,
2009 "error_reason": error_reason,
2010 "param_reqs": param_reqs
2011 }
2012 return info_dict
2013
2014
2015 @staticmethod
2016 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002017 op = kwargs['op']
2018 inputDtypes = op['types'].copy()
2019 if DType.INT8 in inputDtypes:
2020 inputDtypes.remove(DType.INT8)
2021 if DType.UINT8 in inputDtypes:
2022 inputDtypes.remove(DType.UINT8)
2023
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002024 error_name = ErrorIf.OutputZeroPointNotZero
2025 param_reqs = {
2026 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002027 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002028 "shape": None
2029 }
2030 error_result = False
2031 error_reason = "Output DType not INT8 and zero point not 0"
2032
2033 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002034 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01002035 output_dtype = kwargs['output_dtype']
2036 if isinstance(kwargs['qinfo'], tuple):
2037 qinfo = kwargs['qinfo']
2038 output_zero_point = qinfo[1]
2039 else:
2040 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
2041 qinfo = kwargs['qinfo'].ints
2042 output_zero_point = qinfo[1][1]
2043 if op['op'] == Op.AVG_POOL2D:
2044 if input_dtype != DType.INT8 and output_zero_point != 0:
2045 error_result = True
2046 elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002047 error_result = True
2048
2049 info_dict = {
2050 "error_name": error_name,
2051 "error_result": error_result,
2052 "error_reason": error_reason,
2053 "param_reqs": param_reqs
2054 }
2055 return info_dict
2056
Matthew Haddond6ce7252021-09-29 15:35:44 +01002057 @staticmethod
2058 def evAxisSmallerZero(check=False, **kwargs):
2059 error_name = ErrorIf.AxisSmallerZero
2060 param_reqs = {"rank": None, "dtype": None, "shape": None}
2061 error_result = False
2062 error_reason = "Axis smaller than zero"
2063
2064 if check:
2065 axis = kwargs['axis']
2066 if axis < 0:
2067 error_result = True
2068
2069 info_dict = {
2070 "error_name": error_name,
2071 "error_result": error_result,
2072 "error_reason": error_reason,
2073 "param_reqs": param_reqs
2074 }
2075 return info_dict
2076
2077
2078 @staticmethod
2079 def evAxisLargerRank(check=False, **kwargs):
2080 error_name = ErrorIf.AxisLargerRank
2081 param_reqs = {"rank": None, "dtype": None, "shape": None}
2082 error_result = False
2083 error_reason = "Axis larger than rank"
2084
2085 if check:
2086 axis = kwargs['axis']
2087 shape = kwargs['input_shape']
2088 if axis > len(shape):
2089 error_result = True
2090
2091 info_dict = {
2092 "error_name": error_name,
2093 "error_result": error_result,
2094 "error_reason": error_reason,
2095 "param_reqs": param_reqs
2096 }
2097 return info_dict
2098
2099
2100 @staticmethod
2101 def evShapeOfAxisNotOne(check=False, **kwargs):
2102 error_name = ErrorIf.ShapeOfAxisNotOne
2103 param_reqs = {"rank": None, "dtype": None, "shape": None}
2104 error_result = False
2105 error_reason = "shape[axis] is not equal to 1"
2106
2107 if check:
2108 axis = kwargs['axis']
2109 shape = kwargs['output_shape']
2110 if (0 <= axis < len(shape)) and shape[axis] != 1:
2111 error_result = True
2112
2113 info_dict = {
2114 "error_name": error_name,
2115 "error_result": error_result,
2116 "error_reason": error_reason,
2117 "param_reqs": param_reqs
2118 }
2119 return info_dict
2120
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002121
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002122 @staticmethod
2123 def evPadSmallerZero(check=False, **kwargs):
2124 error_name = ErrorIf.PadSmallerZero
2125 param_reqs = {"rank": None, "dtype": None, "shape": None}
2126 error_result = False
2127 error_reason = "At least one pad is smaller than zero"
2128
2129 if check:
Matthew Haddone807aae2021-10-11 18:12:58 +01002130 op = kwargs['op']
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002131 pad = kwargs['pad']
Matthew Haddone807aae2021-10-11 18:12:58 +01002132 if op['op'] == Op.PAD:
2133 for padding in pad:
2134 if min(padding) < 0:
2135 error_result = True
2136 else:
2137 if min(pad) < 0:
2138 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002139
2140 info_dict = {
2141 "error_name": error_name,
2142 "error_result": error_result,
2143 "error_reason": error_reason,
2144 "param_reqs": param_reqs
2145 }
2146 return info_dict
2147
2148
2149 @staticmethod
2150 def evPadLargerEqualKernel(check=False, **kwargs):
2151 error_name = ErrorIf.PadLargerEqualKernel
2152 param_reqs = {"rank": None, "dtype": None, "shape": None}
2153 error_result = False
2154 error_reason = "At least one pad is larger than kernel dimension"
2155
2156 if check:
2157 pad = kwargs['pad']
2158 kernel = kwargs['kernel']
2159 if min(pad) > 0 and min(kernel) > 1:
2160 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
2161 error_result = True
2162
2163 info_dict = {
2164 "error_name": error_name,
2165 "error_result": error_result,
2166 "error_reason": error_reason,
2167 "param_reqs": param_reqs
2168 }
2169 return info_dict
2170
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """ERROR_IF check: the declared pooling output H/W must match the
        shape computed from input size, padding, kernel, and stride.

        Only flags a mismatch when kernel/stride/pad are themselves valid, so
        this does not overlap the Kernel/Stride/Pad ERROR_IF cases."""
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between output shape provided and expected output shape"

        if check:
            # pad = [top, bottom, left, right]
            pad = kwargs['pad']
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            # kernel = [y, x]
            kernel = kwargs['kernel']
            kernel_y, kernel_x = kernel[0], kernel[1]

            # NHWC layout: dims 1 and 2 are height and width.
            input_shape = kwargs['input_shape']
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs['output_shape']
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs['stride']
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            # NOTE: y_correct/x_correct are only bound when both strides are
            # non-zero; that is safe because params_valid below requires
            # min(stride) >= 1 and short-circuits the comparison otherwise.
            if stride_x != 0 and stride_y != 0:
                y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
                x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x

            # ensure parameters are valid
            params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
                and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
2213
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002214 @staticmethod
2215 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
2216 error_name = ErrorIf.ArgmaxOutputShapeMismatch
2217 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2218 error_result = False
2219 error_reason = "Mismatch between output shape provided and expected output shape"
2220
2221 if check:
2222 output_shape = kwargs['output_shape']
2223 input_shape = kwargs['input_shape']
2224 axis = kwargs['axis']
2225
2226 dimension_match = True
2227 axis_shift = 0
2228
2229 # Check that rank is correct before trying to check dimensions
2230 if (len(input_shape) - 1) == len(output_shape):
2231 for i in range(len(input_shape)):
2232 if i == axis:
2233 axis_shift = 1
2234 continue
2235 if input_shape[i] != output_shape[i - axis_shift]:
2236 dimension_match = False
2237
2238 if not dimension_match:
2239 error_result = True
2240
2241 info_dict = {
2242 "error_name": error_name,
2243 "error_result": error_result,
2244 "error_reason": error_reason,
2245 "param_reqs": param_reqs
2246 }
2247 return info_dict
2248
2249 @staticmethod
2250 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2251 error_name = ErrorIf.ArgmaxOutputRankMismatch
2252 param_reqs = {"rank": None, "dtype": None, "shape": None}
2253 error_result = False
2254 error_reason = "Mismatch between output shape provided and expected output shape"
2255
2256 if check:
2257 output_shape = kwargs['output_shape']
2258 input_shape = kwargs['input_shape']
2259 axis = kwargs['axis']
2260 valid_params = axis >= 0 and axis < len(input_shape)
2261
2262 if valid_params and (len(input_shape) - 1) != len(output_shape):
2263 error_result = True
2264
2265 info_dict = {
2266 "error_name": error_name,
2267 "error_result": error_result,
2268 "error_reason": error_reason,
2269 "param_reqs": param_reqs
2270 }
2271 return info_dict
2272
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002273
2274 @staticmethod
2275 def evKernelSmallerOne(check=False, **kwargs):
2276 error_name = ErrorIf.KernelSmallerOne
2277 param_reqs = {"rank": None, "dtype": None, "shape": None}
2278 error_result = False
2279 error_reason = "At least one kernel dimension is smaller than zero"
2280
2281 if check:
2282 kernel = kwargs['kernel']
2283 if min(kernel) < 1:
2284 error_result = True
2285
2286 info_dict = {
2287 "error_name": error_name,
2288 "error_result": error_result,
2289 "error_reason": error_reason,
2290 "param_reqs": param_reqs
2291 }
2292 return info_dict
2293
2294 @staticmethod
2295 def evStrideSmallerOne(check=False, **kwargs):
2296 error_name = ErrorIf.StrideSmallerOne
2297 param_reqs = {"rank": None, "dtype": None, "shape": None}
2298 error_result = False
2299 error_reason = "At least one stride dimension is smaller than zero"
2300
2301 if check:
2302 stride = kwargs['stride']
2303 if min(stride) < 1:
2304 error_result = True
2305
2306 info_dict = {
2307 "error_name": error_name,
2308 "error_result": error_result,
2309 "error_reason": error_reason,
2310 "param_reqs": param_reqs
2311 }
2312 return info_dict
2313
Matthew Haddonc2025212021-10-08 21:21:05 +01002314 @staticmethod
2315 def evScaleTrue(check=False, **kwargs):
2316 error_name = ErrorIf.ScaleTrue
2317 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2318 error_result = False
2319 error_reason = "Scale set to true but input type is INT48"
2320
2321 if check:
2322 input_dtype = kwargs['input_dtype']
2323 scale32 = kwargs['scale32']
2324 if scale32 and input_dtype == DType.INT48:
2325 error_result = True
2326
2327 info_dict = {
2328 "error_name": error_name,
2329 "error_result": error_result,
2330 "error_reason": error_reason,
2331 "param_reqs": param_reqs
2332 }
2333 return info_dict
2334
2335 @staticmethod
2336 def evScaleNotTrue(check=False, **kwargs):
2337 error_name = ErrorIf.ScaleNotTrue
2338 param_reqs = {"rank": None, "dtype": None, "shape": None}
2339 error_result = False
2340 error_reason = "Scale set to false but double round set to true"
2341
2342 if check:
2343 scale32 = kwargs['scale32']
2344 double_round = kwargs['double_round']
2345 if not scale32 and double_round:
2346 error_result = True
2347
2348 info_dict = {
2349 "error_name": error_name,
2350 "error_result": error_result,
2351 "error_reason": error_reason,
2352 "param_reqs": param_reqs
2353 }
2354 return info_dict
2355
Matthew Haddone807aae2021-10-11 18:12:58 +01002356 @staticmethod
2357 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2358 error_name = ErrorIf.TensorSizeInputOutputMismatch
2359 param_reqs = {"rank": None, "dtype": None, "shape": None}
2360 error_result = False
2361 error_reason = "Input tensor size does not match output tensor size"
2362
2363 if check:
2364 input_shape = kwargs['input_shape']
2365 output_shape = kwargs['output_shape']
2366 input_size = np.prod(input_shape)
2367 output_size = np.prod(output_shape)
2368 if input_size != output_size:
2369 error_result = True
2370
2371 info_dict = {
2372 "error_name": error_name,
2373 "error_result": error_result,
2374 "error_reason": error_reason,
2375 "param_reqs": param_reqs
2376 }
2377 return info_dict
2378
2379 @staticmethod
2380 def evStartSmallerZero(check=False, **kwargs):
2381 error_name = ErrorIf.StartSmallerZero
2382 param_reqs = {"rank": None, "dtype": None, "shape": None}
2383 error_result = False
2384 error_reason = "Starting point smaller than zero"
2385
2386 if check:
2387 input_shape = kwargs['input_shape']
2388 start = kwargs['start']
2389 rank = len(input_shape)
2390 if len(start) == rank:
2391 for index in range(rank):
2392 if start[index] < 0:
2393 error_result = True
2394
2395 info_dict = {
2396 "error_name": error_name,
2397 "error_result": error_result,
2398 "error_reason": error_reason,
2399 "param_reqs": param_reqs
2400 }
2401 return info_dict
2402
2403
2404 @staticmethod
2405 def evSizeSmallerEqualZero(check=False, **kwargs):
2406 error_name = ErrorIf.SizeSmallerEqualZero
2407 param_reqs = {"rank": None, "dtype": None, "shape": None}
2408 error_result = False
2409 error_reason = "Size smaller than or equal to zero"
2410
2411 if check:
2412 input_shape = kwargs['input_shape']
2413 size = kwargs['size']
2414 rank = len(input_shape)
2415 if len(size) == rank:
2416 for index in range(rank):
2417 if size[index] <= 0:
2418 error_result = True
2419
2420 info_dict = {
2421 "error_name": error_name,
2422 "error_result": error_result,
2423 "error_reason": error_reason,
2424 "param_reqs": param_reqs
2425 }
2426 return info_dict
2427
2428
2429 @staticmethod
2430 def evStartSizeOutsideBounds(check=False, **kwargs):
2431 error_name = ErrorIf.StartSizeOutsideBounds
2432 param_reqs = {"rank": None, "dtype": None, "shape": None}
2433 error_result = False
2434 error_reason = "starting point plus size larger than input dimension"
2435
2436 if check:
2437 input_shape = kwargs['input_shape']
2438 start = kwargs['start']
2439 size = kwargs['size']
2440 rank = len(input_shape)
2441 if len(start) == rank and len(size) == rank:
2442 for index in range(rank):
2443 if start[index] + size[index] > input_shape[index]:
2444 error_result = True
2445
2446 info_dict = {
2447 "error_name": error_name,
2448 "error_result": error_result,
2449 "error_reason": error_reason,
2450 "param_reqs": param_reqs
2451 }
2452 return info_dict
2453
2454
2455 @staticmethod
2456 def evSizeOutputShapeMismatch(check=False, **kwargs):
2457 error_name = ErrorIf.SizeOutputShapeMismatch
2458 param_reqs = {"rank": None, "dtype": None, "shape": None}
2459 error_result = False
2460 error_reason = "Size does not match output dimension"
2461
2462 if check:
2463 input_shape = kwargs['input_shape']
2464 output_shape = kwargs['output_shape']
2465 size = kwargs['size']
2466 rank = len(input_shape)
2467 if len(size) == rank:
2468 for index in range(rank):
2469 if size[index] != output_shape[index]:
2470 error_result = True
2471
2472 info_dict = {
2473 "error_name": error_name,
2474 "error_result": error_result,
2475 "error_reason": error_reason,
2476 "param_reqs": param_reqs
2477 }
2478 return info_dict
2479
2480 @staticmethod
2481 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2482 error_name = ErrorIf.InputSizeStartLengthMismatch
2483 param_reqs = {"rank": None, "dtype": None, "shape": None}
2484 error_result = False
2485 error_reason = "rank of input not equal to length of start or size"
2486
2487 if check:
2488 input_shape = kwargs['input_shape']
2489 start = kwargs['start']
2490 size = kwargs['size']
2491 rank = len(input_shape)
2492 if rank != len(start) or rank != len(size):
2493 error_result = True
2494
2495 info_dict = {
2496 "error_name": error_name,
2497 "error_result": error_result,
2498 "error_reason": error_reason,
2499 "param_reqs": param_reqs
2500 }
2501 return info_dict
2502
2503 @staticmethod
2504 def evIndexOutsideBounds(check=False, **kwargs):
2505 error_name = ErrorIf.IndexOutsideBounds
2506 param_reqs = {"rank": None, "dtype": None, "shape": None}
2507 error_result = False
2508 error_reason = "Index outside of allowed bounds"
2509
2510 if check:
2511 input_shape = kwargs['input_shape']
2512 perms = kwargs['perms']
2513 rank = len(input_shape)
2514
2515 for index in perms:
2516 if index < 0 or index > rank:
2517 error_result = True
2518
2519 info_dict = {
2520 "error_name": error_name,
2521 "error_result": error_result,
2522 "error_reason": error_reason,
2523 "param_reqs": param_reqs
2524 }
2525 return info_dict
2526
2527 @staticmethod
2528 def evIndexUsedTwice(check=False, **kwargs):
2529 error_name = ErrorIf.IndexUsedTwice
2530 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2531 error_result = False
2532 error_reason = "Index used multiple times"
2533
2534 if check:
2535 input_shape = kwargs['input_shape']
2536 perms = kwargs['perms']
2537 rank = len(input_shape)
2538
2539 unique_indices = []
2540 for index in perms:
2541 if index in unique_indices:
2542 error_result = True
2543 else:
2544 unique_indices.append(index)
2545
2546 info_dict = {
2547 "error_name": error_name,
2548 "error_result": error_result,
2549 "error_reason": error_reason,
2550 "param_reqs": param_reqs
2551 }
2552 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002553
2554
class TosaInvalidValidator:
    """Predicates that detect generated argument combinations which are not
    legal TOSA, so the corresponding test is emitted as invalid/skipped
    rather than expected to pass."""

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode / input dtype / output dtype
        combination is invalid.

        Valid BILINEAR signatures: i8->i32, i16->i48, float->float.
        Valid NEAREST signatures: input dtype equals output dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Bug fix: the original OR'ed the *negations* of the valid
            # combinations. Since at most one combination can match, at least
            # two negations were always true, so every BILINEAR case was
            # reported invalid. "Invalid" means none of the valid combos hold.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / invalid input datatype.
            # NOTE(review): DType.INT32 in this list looks like it should be
            # DType.INT16 (NEAREST inputs are i8/i16/float) - confirm against
            # the TOSA specification before changing.
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if a RESIZE stride is zero or negative.

        FLOAT inputs use the floating-point strides (args[4]); all other
        dtypes use the fixed-point strides (args[1]).
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if the computed output height or width of a
        conv2d / depthwise_conv2d / *pool2d op would be non-positive.

        Expects kwargs: opName, shapeList ([input] or [input, filter]) and
        args ((strides, padding, dilations-or-kernel, ...)).
        """
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        input = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            # Pooling ops carry the kernel in the slot convs use for dilations.
            kernel = args[2]

        if opName.startswith('conv2d'):
            h = (
                input[1]
                - filter[1]
                - (filter[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[2]
                - (filter[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            # Depthwise filters have their kernel dims one position earlier.
            h = (
                input[1]
                - filter[0]
                - (filter[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[1]
                - (filter[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if a transpose_conv2d output height/width (args[3])
        is zero or negative."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
2671
2672
Kevin Cheng550ccc52021-03-03 11:21:43 -08002673
Eric Kunzee5e26762020-10-13 16:11:07 -07002674class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002675 # Maximum rank of tensor supported by test generator.
2676 TOSA_TENSOR_MAX_RANK = 6
2677
    def __init__(self, args):
        """Create a test generator from the parsed command-line arguments.

        args: argparse namespace; output_dir, random_seed and
        tensor_shape_range (used by makeShape) are read from it.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Serializer is created per-test by createSerializer().
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): createDynamicOpLists is called before
        # initOpListDefaults - presumably defaults must also cover the
        # dynamically created ops; confirm before reordering.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
2689
2690 def createSerializer(self, opName, testPath):
2691 self.testPath = os.path.join(opName, testPath)
2692
2693 fullPath = os.path.join(self.basePath, self.testPath)
2694 os.makedirs(fullPath, exist_ok=True)
2695 self.ser = ts.TosaSerializer(fullPath)
2696
    def getSerializer(self):
        """Return the current serializer (set by createSerializer, else None)."""
        return self.ser
2699
2700 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002701 with open(
2702 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
2703 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07002704 fd.write(self.ser.serialize())
2705
Kevin Cheng550ccc52021-03-03 11:21:43 -08002706 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
2707 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002708
Matthew Haddon74567092021-07-16 15:38:20 +01002709 def resetRNG(self, seed=None):
2710 if seed == None:
2711 seed = self.random_seed + 1
2712 self.rng = np.random.default_rng(seed)
2713
Eric Kunzee5e26762020-10-13 16:11:07 -07002714 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07002715 if dtype == DType.BOOL:
2716 np_dt = np.bool
2717 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07002718 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002719 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002720 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002721 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002722 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
2723 elif dtype == DType.UINT8:
2724 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002725 elif dtype == DType.INT16:
2726 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
2727 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002728 return np.int32(
2729 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
2730 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002731 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002732 return np.int64(
2733 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
2734 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002735 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002736 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002737 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002738 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002739
Kevin Cheng989cb052021-04-28 16:29:44 -07002740 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002741 placeholders = []
2742
Kevin Cheng989cb052021-04-28 16:29:44 -07002743 assert len(shape_list) == len(dtype_list)
2744
2745 for idx, shape in enumerate(shape_list):
2746 arr = self.getRandTensor(shape, dtype_list[idx])
2747 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002748
2749 return placeholders
2750
Kevin Cheng989cb052021-04-28 16:29:44 -07002751 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002752 consts = []
2753
Kevin Cheng989cb052021-04-28 16:29:44 -07002754 assert len(shape_list) == len(dtype_list)
2755
2756 for idx, shape in enumerate(shape_list):
2757 arr = self.getRandTensor(shape, dtype_list[idx])
2758 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002759
2760 return consts
2761
2762 def makeShape(self, rank):
2763 if self.targetted_shape:
2764 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002765 return np.int32(
2766 self.rng.integers(
2767 low=self.args.tensor_shape_range[0],
2768 high=self.args.tensor_shape_range[1],
2769 size=rank,
2770 )
2771 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002772
    def setTargetShape(self, shape):
        """Force makeShape to return this exact shape (None to disable)."""
        self.targetted_shape = shape
2775
2776 def randInt(self, low=0, high=256):
2777 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
2778
2779 def getRandNumberDType(self, dtype):
2780 if dtype == DType.FLOAT:
2781 return self.rng.random()
2782 elif dtype == DType.BOOL:
2783 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07002784 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002785 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002786 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002787 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002788 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07002789 elif dtype == DType.INT16:
2790 low, high = (-32768, 32768)
2791 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002792 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07002793 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002794 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07002795 # Special size
2796 return np.int64(self.rng.integers(low, high, size=1))[0]
2797 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002798 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002799
2800 return np.int32(self.rng.integers(low, high, size=1))[0]
2801
2802 def shapeStr(self, shape):
2803
2804 sStr = []
2805 # Convert to strings
2806 for i in shape:
2807 sStr.append(str(i))
2808
Kevin Cheng550ccc52021-03-03 11:21:43 -08002809 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002810
2811 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07002812 if isinstance(t, list):
2813 assert len(t) >= 2
2814 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002815 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002816 if t == DType.BOOL:
2817 return "b"
2818 elif t == DType.INT4:
2819 return "i4"
2820 elif t == DType.INT8:
2821 return "i8"
2822 elif t == DType.UINT8:
2823 return "u8"
2824 elif t == DType.INT16:
2825 return "i16"
2826 elif t == DType.INT32:
2827 return "i32"
2828 elif t == DType.INT48:
2829 return "i48"
2830 elif t == DType.FLOAT:
2831 return "float"
2832 else:
2833 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002834
2835 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002836 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002837 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002838 return 4
2839 elif t == DType.INT8:
2840 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002841 elif t == DType.UINT8:
2842 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002843 elif t == DType.INT16:
2844 return 16
2845 elif t == DType.INT32:
2846 return 32
2847 elif t == DType.INT48:
2848 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01002849 elif t == DType.FLOAT:
2850 return 32
2851 elif t == DType.BOOL:
2852 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002853 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002854 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002855
2856 # Argument generators
2857 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2858 # Where the string descriptor is used to generate the test name and
2859 # The build_fcn_arg_list is expanded and passed to the operator test
2860 # build function
2861
    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        """Build a unary operator test for input tensor `a`.

        op: operator descriptor dict (or a raw int op id for some callers);
        validator_fcns/error_name drive the ERROR_IF negative-test machinery;
        qinfo is optional quantization info. Returns the result tensor.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # build_placeholder returns an int, ABS/other ops does not
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
            return result_tens
        elif op['op'] == Op.IDENTITY:
            self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Record the expected ERROR_IF outcome for this test.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens
2904
    def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
        """Build a broadcastable binary operator test for tensors a and b.

        validator_fcns/error_name drive the ERROR_IF negative-test machinery.
        Returns the result tensor.
        """
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)


        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Record the expected ERROR_IF outcome for this test.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = a,
            input2 = b,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list)
        return result_tens
2933
2934 def build_binary_nonbroadcast(self, op, a, b):
2935 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002936 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002937 return result_tens
2938
Kevin Chengaee1fac2020-11-11 13:54:06 -08002939 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002940 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002941
2942 attr = ts.TosaSerializerAttribute()
2943 attr.ArithmeticRightShiftAttribute(round)
2944
Matthew Haddon848efb42021-09-09 12:30:53 +01002945 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002946 return result_tens
2947
2948 def build_mul(self, op, a, b, shift):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002949 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Eric Kunzee5e26762020-10-13 16:11:07 -07002950
2951 # Special for multiply:
2952 # Force the result to INT32 for INT types
2953 if a.dtype != DType.FLOAT:
2954 result_tens.setDtype(DType.INT32)
2955
Kevin Chengaee1fac2020-11-11 13:54:06 -08002956 attr = ts.TosaSerializerAttribute()
2957 attr.MulAttribute(shift)
2958
Matthew Haddon848efb42021-09-09 12:30:53 +01002959 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002960 return result_tens
2961
Kevin Chengfe392ce2021-10-18 21:51:55 +00002962 def build_table(self, op, a, table):
2963 result_tens = OutputShaper.tableOp(self.ser, a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002964
Kevin Chengfe392ce2021-10-18 21:51:55 +00002965 attr = ts.TosaSerializerAttribute()
2966 attr.TableAttribute(table)
2967
2968 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002969
2970 return result_tens
2971
2972 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07002973 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002974 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002975 return result_tens
2976
2977 def build_comparison(self, op, a, b):
2978 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002979 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002980 return result_tens
2981
    def build_argmax(self, op, a, axis, validator_fcns, error_name):
        """Build ARGMAX over the given axis of tensor `a`.

        validator_fcns/error_name drive the ERROR_IF negative-test machinery.
        Returns the result tensor.
        """
        result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Record the expected ERROR_IF outcome for this test.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape = a.shape,
            input_dtype = a.dtype,
            output_shape = result_tens.shape,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
3013
    def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
        """Build an avg/max pool2d test.

        stride/pad/kernel are the pooling parameters; validator_fcns and
        error_name drive the ERROR_IF negative-test machinery; qinfo is
        optional quantization info. Returns the result tensor.
        """
        result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType:
            if input.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Record the expected ERROR_IF outcome for this test.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=input.shape,
            input_dtype=input.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            kernel=kernel,
            stride=stride,
            pad=pad,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad)

        self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
        return result_tens
3056
3057 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003058 assert len(padding) == 4
3059 result_tens = OutputShaper.conv2dOp(
3060 self.ser, ifm, filter, strides, padding, dilations
3061 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003062
3063 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003064 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003065
Kevin Cheng550ccc52021-03-03 11:21:43 -08003066 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003067 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003068 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003069 return result_tens
3070
Kevin Cheng1533b852021-09-01 12:51:58 -07003071 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
3072 assert len(padding) == 6
3073 result_tens = OutputShaper.conv3dOp(
3074 self.ser, ifm, filter, strides, padding, dilations
3075 )
3076
3077 attr = ts.TosaSerializerAttribute()
3078 attr.ConvAttribute(padding, strides, dilations)
3079
3080 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003081 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07003082 )
3083 return result_tens
3084
Kevin Cheng550ccc52021-03-03 11:21:43 -08003085 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07003086 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003087 ):
3088 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07003089 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
3090
3091 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003092 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003093
Kevin Cheng550ccc52021-03-03 11:21:43 -08003094 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003095 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003096 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003097 return result_tens
3098
Kevin Cheng550ccc52021-03-03 11:21:43 -08003099 def build_depthwise_conv2d(
3100 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
3101 ):
3102 result_tens = OutputShaper.depthwiseConv2dOp(
3103 self.ser, ifm, filter, strides, padding, dilations
3104 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003105
3106 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003107 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003108
Kevin Cheng550ccc52021-03-03 11:21:43 -08003109 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003110 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003111 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003112 return result_tens
3113
    def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
        """Build a FULLY_CONNECTED test.

        validator_fcns/error_name drive the ERROR_IF negative-test machinery;
        qinfo is optional quantization info. Returns the result tensor.
        """
        result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Record the expected ERROR_IF outcome for this test.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=ifm.shape,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(
            op['op'], input_list, output_list, None, qinfo
        )
        return result_tens
3145
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003146 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
3147 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
3148
3149 # Invalidate Input/Output list for error if checks.
3150 input_list = [a.name, b.name]
3151 output_list = [result_tens.name]
3152 pCount, cCount = op["operands"]
3153 num_operands = pCount + cCount
3154 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3155
3156 TosaErrorValidator.evValidateErrorIfs(
3157 self.ser,
3158 validator_fcns,
3159 error_name,
3160 op=op,
3161 input_shape=a.shape,
3162 input_dtype=a.dtype,
3163 input2_shape=b.shape,
3164 input2_dtype=b.dtype,
3165 output_shape=result_tens.shape,
3166 output_dtype=result_tens.dtype,
3167 qinfo = qinfo,
3168 result_tensor = result_tens,
3169 input_list=input_list,
3170 output_list=output_list,
3171 num_operands=num_operands,
3172 )
3173
3174 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003175 return result_tens
3176
Matthew Haddond6ce7252021-09-29 15:35:44 +01003177 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
3178 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
3179
3180 # Invalidate Input/Output list for error if checks.
3181 input_list = [a.name]
3182 output_list = [result_tens.name]
3183 pCount, cCount = op["operands"]
3184 num_operands = pCount + cCount
3185 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3186
3187 TosaErrorValidator.evValidateErrorIfs(
3188 self.ser,
3189 validator_fcns,
3190 error_name,
3191 op=op,
3192 axis = axis,
3193 input_shape = a.shape,
3194 output_shape = result_tens.shape,
3195 input_dtype = a.dtype,
3196 output_dtype = result_tens.dtype,
3197 result_tensor = result_tens,
3198 input_list=input_list,
3199 output_list=output_list,
3200 num_operands=num_operands,
3201 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003202
3203 attr = ts.TosaSerializerAttribute()
3204 attr.AxisAttribute(axis)
3205
Matthew Haddond6ce7252021-09-29 15:35:44 +01003206 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003207 return result_tens
3208
3209 def build_clamp(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003210 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003211
3212 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01003213 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07003214
3215 if a.dtype == DType.FLOAT:
3216 attr.ClampAttribute(0, 0, min(v), max(v))
3217 else:
3218 attr.ClampAttribute(min(v), max(v), 0, 0)
3219
Matthew Haddon848efb42021-09-09 12:30:53 +01003220 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003221 return result_tens
3222
3223 def build_leaky_relu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003224 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003225 attr = ts.TosaSerializerAttribute()
3226
3227 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
3228
Matthew Haddon848efb42021-09-09 12:30:53 +01003229 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003230 return result_tens
3231
3232 # Needs an additional type/input
3233 def build_prelu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003234 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003235
Matthew Haddon848efb42021-09-09 12:30:53 +01003236 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003237 return result_tens
3238
Eric Kunzee5e26762020-10-13 16:11:07 -07003239 def build_sigmoid(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003240 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01003241 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003242 return result_tens
3243
3244 def build_tanh(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003245 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01003246 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003247 return result_tens
3248
Matthew Haddon818ab902021-07-27 09:12:49 +01003249 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07003250 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01003251
3252 # To store variable length list of input tensors we need to store axis along with it
3253 axis = a[-1]
3254 a = a[:-1]
3255
3256 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07003257
3258 attr = ts.TosaSerializerAttribute()
3259 attr.AxisAttribute(axis)
3260
Matthew Haddon818ab902021-07-27 09:12:49 +01003261 input_tensor_names = []
3262 for tensor in a:
3263 input_tensor_names.append(tensor.name)
3264
Matthew Haddon848efb42021-09-09 12:30:53 +01003265 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
3266 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003267
Kevin Chengfe392ce2021-10-18 21:51:55 +00003268 def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
Matthew Haddone807aae2021-10-11 18:12:58 +01003269 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003270
Kevin Chengfe392ce2021-10-18 21:51:55 +00003271 attr = ts.TosaSerializerAttribute()
3272 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07003273
Matthew Haddone807aae2021-10-11 18:12:58 +01003274 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00003275 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01003276 output_list = [result_tens.name]
3277 pCount, cCount = op["operands"]
3278 num_operands = pCount + cCount
3279 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3280
3281 TosaErrorValidator.evValidateErrorIfs(
3282 self.ser,
3283 validator_fcns,
3284 error_name,
3285 op=op,
3286 input_shape = a.shape,
3287 output_shape = result_tens.shape,
3288 input_dtype = a.dtype,
3289 output_dtype = result_tens.dtype,
3290 pad=padding,
3291 qinfo=qinfo,
3292 result_tensor = result_tens,
3293 input_list=input_list,
3294 output_list=output_list,
3295 num_operands=num_operands,
3296 )
3297
Kevin Cheng550ccc52021-03-03 11:21:43 -08003298 self.ser.addOperator(
Kevin Chengfe392ce2021-10-18 21:51:55 +00003299 op['op'], input_list, output_list, attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003300 )
Matthew Haddone86fd342021-09-07 16:12:21 +01003301 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003302
Matthew Haddone807aae2021-10-11 18:12:58 +01003303 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
3304 result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)
3305
3306 # Invalidate Input/Output list for error if checks.
3307 input_list = [a.name]
3308 output_list = [result_tens.name]
3309 pCount, cCount = op["operands"]
3310 num_operands = pCount + cCount
3311 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3312
3313 TosaErrorValidator.evValidateErrorIfs(
3314 self.ser,
3315 validator_fcns,
3316 error_name,
3317 op=op,
3318 input_shape = a.shape,
3319 output_shape = result_tens.shape,
3320 input_dtype = a.dtype,
3321 output_dtype = result_tens.dtype,
3322 result_tensor = result_tens,
3323 input_list=input_list,
3324 output_list=output_list,
3325 num_operands=num_operands,
3326 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003327
3328 attr = ts.TosaSerializerAttribute()
3329 attr.ReshapeAttribute(newShape)
3330
Matthew Haddone807aae2021-10-11 18:12:58 +01003331 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003332 return result_tens
3333
3334 def build_reverse(self, op, a, axis):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003335 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07003336
3337 attr = ts.TosaSerializerAttribute()
3338 attr.AxisAttribute(axis)
3339
Matthew Haddon848efb42021-09-09 12:30:53 +01003340 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003341 return result_tens
3342
Matthew Haddone807aae2021-10-11 18:12:58 +01003343 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
3344 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003345
Kevin Chengfe392ce2021-10-18 21:51:55 +00003346 attr = ts.TosaSerializerAttribute()
3347 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07003348
Matthew Haddone807aae2021-10-11 18:12:58 +01003349 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00003350 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01003351 output_list = [result_tens.name]
3352 pCount, cCount = op["operands"]
3353 num_operands = pCount + cCount
3354 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3355
3356 TosaErrorValidator.evValidateErrorIfs(
3357 self.ser,
3358 validator_fcns,
3359 error_name,
3360 op=op,
3361 input_shape = a.shape,
3362 output_shape = result_tens.shape,
3363 perms=perms,
3364 input_dtype = a.dtype,
3365 output_dtype = result_tens.dtype,
3366 result_tensor = result_tens,
3367 input_list=input_list,
3368 output_list=output_list,
3369 num_operands=num_operands,
3370 )
3371
3372
Kevin Chengfe392ce2021-10-18 21:51:55 +00003373 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003374 return result_tens
3375
Matthew Haddone807aae2021-10-11 18:12:58 +01003376 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
3377 result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)
3378
3379 # Invalidate Input/Output list for error if checks.
3380 input_list = [a.name]
3381 output_list = [result_tens.name]
3382 pCount, cCount = op["operands"]
3383 num_operands = pCount + cCount
3384 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3385
3386 TosaErrorValidator.evValidateErrorIfs(
3387 self.ser,
3388 validator_fcns,
3389 error_name,
3390 op=op,
3391 input_shape = a.shape,
3392 output_shape = result_tens.shape,
3393 input_dtype = a.dtype,
3394 output_dtype = result_tens.dtype,
3395 start=start,
3396 size=size,
3397 result_tensor = result_tens,
3398 input_list=input_list,
3399 output_list=output_list,
3400 num_operands=num_operands,
3401 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003402
3403 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01003404 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07003405
Matthew Haddone807aae2021-10-11 18:12:58 +01003406 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003407 return result_tens
3408
3409 def build_tile(self, op, a, multiples):
3410 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
3411
3412 attr = ts.TosaSerializerAttribute()
3413 attr.TileAttribute(multiples)
3414
Matthew Haddon848efb42021-09-09 12:30:53 +01003415 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003416 return result_tens
3417
Kevin Cheng77d0f762020-11-24 10:26:32 -08003418 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07003419
3420 # Create a new indicies tensor
3421 # here with data that doesn't exceed the dimensions of the values tensor
3422
Kevin Cheng550ccc52021-03-03 11:21:43 -08003423 K = values.shape[1] # K
3424 W = self.randInt(
3425 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
3426 ) # W
3427 indicies_arr = np.int32(
3428 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
3429 ) # (N, W)
3430 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003431
Kevin Cheng77d0f762020-11-24 10:26:32 -08003432 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07003433
Matthew Haddon848efb42021-09-09 12:30:53 +01003434 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003435
3436 return result_tens
3437
Kevin Cheng77d0f762020-11-24 10:26:32 -08003438 def build_scatter(self, op, values_in, input):
3439
3440 # Create a new indicies tensor
3441 # here with data that doesn't exceed the dimensions of the values_in tensor
3442
Kevin Cheng550ccc52021-03-03 11:21:43 -08003443 K = values_in.shape[1] # K
3444 W = input.shape[1] # W
3445 indicies_arr = np.int32(
3446 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
3447 ) # (N, W)
3448 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003449
3450 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
3451
Kevin Cheng550ccc52021-03-03 11:21:43 -08003452 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003453 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003454 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08003455
3456 return result_tens
3457
Matthew Haddon848efb42021-09-09 12:30:53 +01003458
Kevin Cheng550ccc52021-03-03 11:21:43 -08003459 def build_resize(
3460 self,
3461 op,
3462 input,
3463 mode,
3464 stride,
3465 offset,
3466 shift,
3467 stride_fp,
3468 offset_fp,
3469 output_dims,
3470 input_dtype,
3471 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003472 validator_fcns,
3473 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003474 ):
3475 result_tens = OutputShaper.resizeOp(
3476 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003477 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003478 input,
3479 mode,
3480 stride,
3481 offset,
3482 shift,
3483 stride_fp,
3484 offset_fp,
3485 output_dims,
3486 input_dtype,
3487 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003488 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08003489 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003490
Matthew Haddon848efb42021-09-09 12:30:53 +01003491 # Invalidate Input/Output list for error if checks.
3492 input_list = [input.name]
3493 output_list = [result_tens.name]
3494 pCount, cCount = op["operands"]
3495 num_operands = pCount + cCount
3496 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01003497
Matthew Haddon848efb42021-09-09 12:30:53 +01003498 TosaErrorValidator.evValidateErrorIfs(
3499 self.ser,
3500 validator_fcns,
3501 error_name,
3502 op=op,
3503 mode=mode,
3504 shift=shift,
3505 input_dtype=input_dtype,
3506 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003507 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01003508 output_shape=output_dims,
3509 offset=offset,
3510 offset_fp=offset_fp,
3511 stride=stride,
3512 stride_fp=stride_fp,
3513 input_list=input_list,
3514 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003515 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01003516 num_operands=num_operands,
3517 )
Matthew Haddone86fd342021-09-07 16:12:21 +01003518
Eric Kunzee5e26762020-10-13 16:11:07 -07003519 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08003520
Kevin Cheng550ccc52021-03-03 11:21:43 -08003521 attr.ResizeAttribute(
3522 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
3523 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003524
Matthew Haddon848efb42021-09-09 12:30:53 +01003525 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003526 return result_tens
3527
3528 def build_identityn(self, op, val, val2):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003529 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
3530 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003531 self.ser.addOperator(
3532 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
3533 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003534 return result_tens
3535
Kevin Cheng17e92022021-10-01 14:33:33 -07003536 def build_const(self, op, val):
3537 self.ser.addOutputTensor(val)
3538 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07003539
3540 # Type Conversion
3541 def build_cast(self, op, val, out_dtype):
3542 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01003543 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003544 return result_tens
3545
Matthew Haddonc2025212021-10-08 21:21:05 +01003546 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
Eric Kunzee5e26762020-10-13 16:11:07 -07003547 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
3548
3549 if per_channel:
3550 nc = val.shape[-1]
3551 else:
3552 nc = 1
3553
3554 in_type_width = self.typeWidth(val.dtype)
3555 out_type_width = self.typeWidth(out_dtype)
3556
Kevin Cheng3a478572021-01-22 17:21:02 -08003557 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003558 input_zp = self.randInt(-128, 128)
3559 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07003560 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003561 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07003562 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01003563 elif error_name == ErrorIf.InputZeroPointNotZero:
3564 input_zp = self.randInt(-128, 128)
3565 if input_zp == 0:
3566 input_zp = input_zp + self.rng.integers(1, 10)
3567 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003568 else:
3569 input_zp = 0
3570
Kevin Cheng3a478572021-01-22 17:21:02 -08003571 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003572 output_zp = self.randInt(-128, 128)
3573 out_type_width = out_type_width + 1
3574 elif out_dtype == DType.UINT8:
3575 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07003576 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01003577 elif error_name == ErrorIf.OutputZeroPointNotZero:
3578 output_zp = self.randInt(-128, 128)
3579 if output_zp == 0:
3580 output_zp = output_zp + self.rng.integers(1, 10)
3581 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003582 else:
3583 output_zp = 0
3584
3585 # Calculate scale based on:
3586 # scale = a *(2^output_width)/(2^input_width))
3587
3588 a = np.float32(self.rng.random(size=[nc]))
3589 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
3590
3591 if scale32:
3592 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01003593 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07003594 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
3595 else:
3596 # Cap the scaling at 2^15 - 1 for scale16
3597 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3598
Kevin Cheng550ccc52021-03-03 11:21:43 -08003599 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003600
3601 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3602 shift_arr = np.int32(np.zeros(shape=[nc]))
3603
3604 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003605 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
3606 scale_arr[i], scale32
3607 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003608
Kevin Cheng550ccc52021-03-03 11:21:43 -08003609 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07003610
Matthew Haddonc2025212021-10-08 21:21:05 +01003611 # Invalidate Input/Output list for error if checks.
3612 input_list = [val.name]
3613 output_list = [result_tens.name]
3614 pCount, cCount = op["operands"]
3615 num_operands = pCount + cCount
3616 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3617
3618 qinfo = (input_zp, output_zp)
3619 TosaErrorValidator.evValidateErrorIfs(
3620 self.ser,
3621 validator_fcns,
3622 error_name,
3623 op=op,
3624 input_dtype=val.dtype,
3625 output_dtype=out_dtype,
3626 input_shape=val.shape,
3627 qinfo=qinfo,
3628 scale32 = scale32,
3629 double_round = double_round,
3630 input_list=input_list,
3631 output_list=output_list,
3632 result_tensor=result_tens,
3633 num_operands=num_operands,
3634 )
3635
Eric Kunzee5e26762020-10-13 16:11:07 -07003636 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003637 attr.RescaleAttribute(
3638 input_zp,
3639 output_zp,
3640 multiplier_arr,
3641 shift_arr,
3642 scale32,
3643 double_round,
3644 per_channel,
3645 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003646
Matthew Haddonc2025212021-10-08 21:21:05 +01003647 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003648 return result_tens
3649
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Serialize a COND_IF operator whose branches are constant tensors.

        For cond_if with constants, we're supplied with then/else tensors
        that we ignore (except for the generated shape) and the condition.
        Build Then/Else blocks and fill them with const nodes for the body.
        """
        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
3685
    def build_cond_if_binary(self, op, a, b, cond):
        """Serialize a COND_IF operator with a binary op in each branch.

        For cond_if with a binary op in the then/else blocks, take a and b
        and alternately add or subtract them (or shift, for narrow int
        types) based on the condition.
        """
        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # Pick branch ops that are closed over the dtype (shifts for the
        # narrow int types, add/sub otherwise).
        if a.dtype in (DType.FLOAT, DType.INT32):
            then_op, else_op = Op.ADD, Op.SUB
        elif a.dtype in (DType.INT8, DType.INT16):
            then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
        else:
            assert False, f"No tests for DType: {a.dtype}"

        # NOTE: the loop variable deliberately reuses the name `op`; the
        # original operator dict is not needed past this point.
        for block, op in ((then_block, then_op), (else_block, else_op)):
            self.ser.startBasicBlock(block)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
            self.ser.addOperator(op, [a.name, b.name], [tens.name])

        return result_tens
3721
    def build_while_loop(self, op, a, iter_val):
        """Serialize a WHILE_LOOP that accumulates `a` into an accumulator
        iter_val times.

        Builds the main operator plus its COND and BODY basic blocks and
        returns the accumulator output tensor.
        """
        # NOTE(review): `iter` shadows the builtin; kept for byte-identity.
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # COND block (input: iter, output: cond_tens )
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        # Loop continues while iter > 0.
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        # acc += a; iter -= 1 each trip around the loop.
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
3775
Matthew Haddon1c00b712021-10-01 15:51:03 +01003776 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
3777 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
3778 default_test_rank_range = range(1, 5)
3779 if not shapeFilter:
3780 shapeFilter = [None]
3781
3782 # Calculate the filters based on what is requested and what the operator allows
3783 rmin, rmax = op["rank"]
3784 if rankFilter is not None:
3785 cleanRankFilter = []
3786 # Ensure rankFilter values are allowed by operator
3787 for rank in rankFilter:
3788 if rank >= rmin and rank <= rmax:
3789 cleanRankFilter.append(rank)
3790 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01003791 # Ensure default behaviour is bounded by default range or by operator,
3792 # whichever is the smaller range of ranks.
3793 opRankRange = range(rmin, rmax + 1)
3794 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01003795 else:
3796 cleanRankFilter = range(rmin, rmax + 1)
3797
3798 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003799
Matthew Haddon1c00b712021-10-01 15:51:03 +01003800 if dtypeFilter is not None:
3801 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01003802 # Create list of operator dtypes filtered by requested dtypes
3803 for dtype in dtypes:
3804 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01003805 cleanDtypeFilter.append(dtype)
3806 else:
3807 cleanDtypeFilter = dtypes
3808
3809 if testType == 'positive':
3810 filterDict = {
3811 'shapeFilter': shapeFilter,
3812 'rankFilter': cleanRankFilter,
3813 'dtypeFilter': cleanDtypeFilter
3814 }
3815 return filterDict
3816 elif testType == 'negative':
Matthew Haddone807aae2021-10-11 18:12:58 +01003817 if validator is not None:
3818 validator_info = validator(check=False, op=op)
3819 else:
3820 return None
3821
Matthew Haddon1c00b712021-10-01 15:51:03 +01003822 error_arguments = validator_info['param_reqs']
3823
3824 #Set parameters as required
3825 if error_arguments['rank'] != None:
3826 rankFilter = error_arguments['rank']
3827 else:
3828 rankFilter = cleanRankFilter
3829
3830 if error_arguments['dtype'] != None:
3831 dtypeFilter = error_arguments['dtype']
3832 else:
3833 dtypeFilter = cleanDtypeFilter
3834
3835 if error_arguments['shape'] != None:
3836 shapeFilter = error_arguments['shape']
3837 else:
3838 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
3839
3840 filterDict = {
3841 'shapeFilter': shapeFilter,
3842 'rankFilter': rankFilter,
3843 'dtypeFilter': dtypeFilter
3844 }
3845 return filterDict
3846
3847
Kevin Cheng550ccc52021-03-03 11:21:43 -08003848 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01003849 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08003850 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07003851
3852 try:
3853 op = self.TOSA_OP_LIST[opName]
3854 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003855 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003856
3857 # Initialize a new random number generator
3858 self.rng = np.random.default_rng(self.random_seed)
3859
Kevin Cheng550ccc52021-03-03 11:21:43 -08003860 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003861
Eric Kunzee5e26762020-10-13 16:11:07 -07003862 # Test list consists of a tuple of:
3863 # (opName, testNameStr, dtype, shapeList, argumentsList)
3864 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003865 if testType == 'negative' and "error_if_validators" in op:
3866 error_if_validators = op["error_if_validators"]
3867 else:
3868 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07003869
Matthew Haddon1c00b712021-10-01 15:51:03 +01003870 for validator in error_if_validators:
3871 if validator is not None:
3872 error_name = validator(check=False, op=op)['error_name']
Matthew Haddon1c00b712021-10-01 15:51:03 +01003873 else:
3874 error_name = None
3875
3876 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
Matthew Haddone807aae2021-10-11 18:12:58 +01003877 if filterDict == None:
3878 return []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003879 cleanRankFilter = filterDict['rankFilter']
3880 cleanDtypeFilter = filterDict['dtypeFilter']
3881 cleanShapeFilter = filterDict['shapeFilter']
3882 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
3883
3884 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07003885 if opName.startswith("conv3d"):
3886 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01003887 for t in cleanDtypeFilter:
3888 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01003889 # Filter out by rank
3890 if shape is not None and len(shape) != r:
3891 continue
Matthew Haddon74567092021-07-16 15:38:20 +01003892 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003893 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003894
Matthew Haddon74567092021-07-16 15:38:20 +01003895 shapeStr = self.shapeStr(shapeList[0])
3896 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07003897
Matthew Haddon74567092021-07-16 15:38:20 +01003898 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
3899 argList = []
3900 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003901 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003902 else:
Matthew Haddon74567092021-07-16 15:38:20 +01003903 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07003904
Matthew Haddon74567092021-07-16 15:38:20 +01003905 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003906 if testType == 'positive':
3907 if argStr:
3908 testStr = "{}_{}_{}_{}".format(
3909 opName, shapeStr, typeStr, argStr
3910 )
3911 else:
3912 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
3913 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01003914 if argStr:
3915 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
3916 opName, error_name, shapeStr, typeStr, argStr
3917 )
3918 else:
3919 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003920
3921 testList.append((opName, testStr, t, error_name, shapeList, args))
3922
3923 if testType == 'positive':
3924 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
3925 if "invalid_test_validators" in op:
3926 invalid_test_validators = op["invalid_test_validators"]
3927 clean_testList = []
3928 for test in testList:
3929 for validator_fcn in invalid_test_validators:
3930 remove_test = False
3931 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
3932 remove_test = True
3933 if not remove_test:
3934 clean_testList.append(test)
3935 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07003936
3937 return testList
3938
Matthew Haddone86fd342021-09-07 16:12:21 +01003939
3940 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07003941 try:
3942 op = self.TOSA_OP_LIST[opName]
3943 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003944 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003945
3946 # Create a serializer
3947 self.createSerializer(opName, testStr)
3948
Kevin Cheng550ccc52021-03-03 11:21:43 -08003949 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003950 if "error_if_validators" in op:
3951 error_if_validators = op["error_if_validators"]
3952 else:
3953 error_if_validators = None
3954
Kevin Cheng550ccc52021-03-03 11:21:43 -08003955 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003956 num_operands = pCount + cCount
3957
3958 if isinstance(dtype_or_dtypeList, list):
3959 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07003960 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003961 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003962 else:
3963 dtypeList = [dtype_or_dtypeList] * (num_operands)
3964
Kevin Cheng93a16282021-08-31 16:14:03 -07003965 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003966 assert (
3967 len(shapeList) == num_operands
3968 ), "shapeList length {} must match number of operands {}".format(
3969 len(shapeList), num_operands
3970 )
3971 assert (
3972 len(dtypeList) == num_operands
3973 ), "dtypeList length {} must match number of operands {}".format(
3974 len(dtypeList), num_operands
3975 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003976
3977 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003978 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003979 except KeyError:
3980 qgen = None
3981
3982 # Build the random tensor operands and the test
3983 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08003984
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003985 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003986
3987 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003988 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003989 else:
3990 qinfo = None
3991
3992 try:
3993 if error_if_validators is None:
3994 if qinfo is not None:
3995 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
3996 else:
3997 resultName = build_fcn(self, op, *tens, *testArgs)
3998 else:
3999 if qinfo is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004000 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01004001 else:
4002 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
4003 except TypeError as e:
4004 print(
4005 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
4006 build_fcn, tens, testArgs
4007 )
4008 )
4009 raise e
4010
4011 if resultName is None:
4012 print("Invalid ERROR_IF tests created")
4013
4014 # Save the serialized test
4015 self.serialize("test")
4016
4017
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004018 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01004019 pCount, cCount = op["operands"]
4020
4021 tens = []
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004022 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
Jeremy Johnsonef509a42021-09-07 13:59:47 +01004023 # Make sure the operation does not cause value saturation - where
4024 # the number wraps due to limited number of bits to store the answer
4025 assert (
4026 pCount == 2 and cCount == 0
4027 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01004028 placeholders = []
4029 add = (op["op"] == Op.ADD)
4030 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
4031 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
4032 if add:
4033 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
4034 else:
4035 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
4036
4037 # Work out the saturation limits
4038 max_i32 = (1 << 31)-1
4039 min_i32 = -(1 << 31)
4040 max_arr = np.full(shapeList[1], max_i32)
4041 min_arr = np.full(shapeList[1], min_i32)
4042
4043 # Find how much values exceed the maximum/minimums
4044 sat_max_arr = np.maximum(res_arr - max_arr, 0)
4045 sat_min_arr = np.minimum(res_arr - min_arr, 0)
4046
4047 if not add:
4048 # Swap saturation values and negate values as we need to perform opposite operations
4049 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
4050
4051 # Create new array of unsaturated values by clipping values as needed
4052 b_unsat_arr = b_arr
4053 if (sat_max_arr != 0).any():
4054 # Clip values that cause saturation
4055 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
4056 # Reduce axes in unsaturated tensor to match original tensor
4057 for axis, dim in enumerate(b_arr.shape):
4058 if dim != b_unsat_arr.shape[axis]:
4059 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
4060 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
4061
4062 if (sat_min_arr != 0).any():
4063 # Clip values that cause saturation
4064 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
4065 # Reduce axes in unsaturated tensor to match original tensor
4066 for axis, dim in enumerate(b_arr.shape):
4067 if dim != b_unsat_arr.shape[axis]:
4068 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
4069 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
4070
4071 placeholders.append(
4072 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
4073 )
4074 placeholders.append(
4075 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
4076 )
4077
4078 tens.extend(placeholders)
Jeremy Johnson8c06a652021-10-20 15:51:11 +01004079 elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
4080 # Limit input tensors with cond_if_binary or while_loop to stop
4081 # saturation of add/sub ops
4082 pRemain = pCount
4083 placeholders = []
4084 for idx, shape in enumerate(shapeList[:]):
4085 arr = self.getRandTensor(shapeList[idx], DType.INT16)
4086 if pRemain > 0:
4087 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
4088 pRemain -= 1
4089 else:
4090 placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
4091
4092 tens.extend(placeholders)
Jeremy Johnsonef509a42021-09-07 13:59:47 +01004093 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
4094 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004095 assert (
4096 pCount == 2 and cCount == 0
4097 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08004098
4099 placeholders = []
4100 for idx, shape in enumerate(shapeList[:]):
4101 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07004102 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08004103 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07004104 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08004105 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07004106 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08004107 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
4108 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004109 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08004110 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004111 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07004112 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08004113
4114 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01004115 elif op["op"] == Op.SELECT:
4116 # Set datatype of condition tensor to boolean
4117 dtypeList[0] = DType.BOOL
4118 tens.extend(
4119 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
4120 )
4121 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004122 elif op["op"] == Op.INTDIV and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004123 assert (
4124 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01004125 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004126
4127 placeholders = []
4128
Matthew Haddon459443c2021-08-23 16:43:13 +01004129 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004130 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07004131 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004132 while True:
4133 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
4134 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
4135
4136 if (divisor_arr == 0).any():
4137 continue
4138
Kevin Cheng47315e12021-05-13 17:41:28 -07004139 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004140 continue
4141
4142 break
4143
4144 placeholders.append(
4145 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
4146 )
4147 placeholders.append(
4148 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
4149 )
4150
4151 tens.extend(placeholders)
4152 elif op["op"] == Op.MUL:
4153 assert (
4154 pCount == 2 and cCount == 0
4155 ), "Op.MUL must have 2 placeholders, 0 consts"
4156
4157 if dtypeList[0] == DType.FLOAT:
4158 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
4159 else:
4160 placeholders = []
4161
4162 # Make sure multiply result in int32 range
4163 shift = testArgs[0]
4164 if dtypeList[0] == DType.INT8:
4165 num_bits = 8
4166 elif dtypeList[0] == DType.INT16:
4167 num_bits = 16
4168 elif dtypeList[0] == DType.INT32:
4169 num_bits = 32
4170 else:
4171 raise Exception("OpMul: invalid input dtype")
4172
4173 for idx, shape in enumerate(shapeList[:]):
4174 low = -(2 ** (num_bits - 1))
4175 high = (2 ** (num_bits - 1)) - 1
4176
4177 a_arr = np.int32(
4178 self.rng.integers(low=low, high=high, size=shapeList[0])
4179 )
4180 b_arr = np.int32(
4181 self.rng.integers(low=low, high=high, size=shapeList[1])
4182 )
4183
4184 i = 0
4185 while True:
4186
4187 a_arr_64 = a_arr.astype(np.int64)
4188 b_arr_64 = b_arr.astype(np.int64)
4189
4190 if shift > 0:
4191 rounding = 1 << (shift - 1)
4192 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
4193 else:
4194 result_arr = a_arr_64 * b_arr_64
4195
4196 if (result_arr > -(2 ** 31)).all() and (
4197 result_arr <= ((2 ** 31) - 1)
4198 ).all():
4199 break
4200
4201 i = i + 1
4202 a_arr = a_arr // 2
4203 b_arr = b_arr // 2
4204
4205 placeholders.append(
4206 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
4207 )
4208 placeholders.append(
4209 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
4210 )
4211
4212 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01004213 elif op["op"] == Op.CONCAT:
4214 count = len(shapeList) - self.args.num_const_inputs_concat
4215 if count < 1:
4216 count = 1
4217 if self.args.num_const_inputs_concat == 0:
4218 count = len(shapeList)
4219
4220 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
4221 tens.extend(
4222 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
4223 )
4224 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08004225 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07004226 tens.extend(
4227 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
4228 )
4229 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07004230
Matthew Haddon1c00b712021-10-01 15:51:03 +01004231 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004232
4233 def createDynamicOpLists(self):
4234
4235 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07004236 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004237
Kevin Cheng1533b852021-09-01 12:51:58 -07004238 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004239 testName = "conv2d_{}x{}".format(k[0], k[1])
4240 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
4241 self.TOSA_OP_LIST[testName]["filter"] = k
4242 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004243
Kevin Cheng550ccc52021-03-03 11:21:43 -08004244 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
4245 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
4246 "depthwise_conv2d_TEMPLATE"
4247 ].copy()
4248 self.TOSA_OP_LIST[testName]["filter"] = k
4249 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004250
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
4252 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
4253 "transpose_conv2d_TEMPLATE"
4254 ].copy()
4255 self.TOSA_OP_LIST[testName]["filter"] = k
4256 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07004257
Kevin Cheng1533b852021-09-01 12:51:58 -07004258 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
4259 for k in KERNELS_3D:
4260 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
4261 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
4262 self.TOSA_OP_LIST[testName]["filter"] = k
4263 self.TOSA_OP_LIST[testName]["template"] = False
4264
Eric Kunzee5e26762020-10-13 16:11:07 -07004265 # Delete any templates after having created any dynamic ops
4266 # This is a two-pass operation because it's bad practice to delete
4267 # keys from dictionaries while iterating
4268 keyList = []
4269 for k in self.TOSA_OP_LIST:
4270 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004271 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07004272 keyList.append(k)
4273 continue
4274 except KeyError:
4275 pass
4276
4277 for k in keyList:
4278 del self.TOSA_OP_LIST[k]
4279
4280 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004281 """Fill in default fields for ops if they aren't already specified.
4282 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07004283 for op in self.TOSA_OP_LIST:
4284
4285 # Required fields
4286 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004287 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004288 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004289 raise Exception(
4290 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
4291 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004292
4293 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004294 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004295 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004296 raise Exception(
4297 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
4298 op
4299 )
4300 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004301
4302 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004303 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004304 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004305 raise Exception(
4306 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
4307 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004308
4309 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004310 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004311 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004312 raise Exception(
4313 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
4314 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004315
4316 # Put in default rank range, if missing
4317 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004318 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07004319 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004320 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07004321
    # Tensor operator list
    #  'op': op name
    #  'operands': tuple of (placeholder, const) operands
    #  'rank': optional, restricts rank to tuple inclusive of (min, max),
    #   if not specified, defaults to (1, 4)
    #  'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    #  'types': array of datatypes to be tested

    # Datatype groups used by the 'types' field of TOSA_OP_LIST entries.
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Convolution type combinations: each list entry is
    # [input dtype, weight dtype, accumulator dtype]; plain FLOAT uses the
    # same dtype throughout.
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Rank range applied by initOpListDefaults when an op has no 'rank' field
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07004349
4350 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08004351 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004352 "argmax": {
4353 "op": Op.ARGMAX,
4354 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004355 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004356 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4357 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004358 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
4359 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4360 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004361 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004362 "avg_pool2d": {
4363 "op": Op.AVG_POOL2D,
4364 "operands": (1, 0),
4365 "rank": (4, 4),
4366 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4367 "qgen": TosaQuantGen.qgUnary,
4368 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004369 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4370 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4371 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4372 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4373 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004374 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004375 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004376 "conv2d_TEMPLATE": {
4377 "op": Op.CONV2D,
4378 "operands": (1, 2),
4379 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01004380 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004381 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004382 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004383 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004384 "template": True,
4385 },
Kevin Cheng1533b852021-09-01 12:51:58 -07004386 # Templated operator. Filled in by createDynamicOpLists
4387 "conv3d_TEMPLATE": {
4388 "op": Op.CONV3D,
4389 "operands": (1, 2),
4390 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01004391 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07004392 "qgen": TosaQuantGen.qgConv,
4393 "types": TYPE_CONV,
4394 "template": True,
4395 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004396 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004397 "depthwise_conv2d_TEMPLATE": {
4398 "op": Op.DEPTHWISE_CONV2D,
4399 "operands": (1, 2),
4400 "filter": [1, 1],
4401 "rank": (4, 4),
4402 "build_fcn": (
4403 build_depthwise_conv2d,
4404 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01004405 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004406 ),
4407 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004408 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004409 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004410 "template": True,
4411 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004412 "fully_connected": {
4413 "op": Op.FULLY_CONNECTED,
4414 "operands": (1, 2),
4415 "rank": (2, 2),
4416 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
4417 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004418 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004419 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
4420 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004421 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004422 "matmul": {
4423 "op": Op.MATMUL,
4424 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07004425 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08004426 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
4427 "qgen": TosaQuantGen.qgMatmul,
4428 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004429 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4430 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004431 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004432 "max_pool2d": {
4433 "op": Op.MAX_POOL2D,
4434 "operands": (1, 0),
4435 "rank": (4, 4),
4436 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4437 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004438 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4439 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4440 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4441 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004442 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004443 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004444 "transpose_conv2d_TEMPLATE": {
4445 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07004446 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 "rank": (4, 4),
4448 "build_fcn": (
4449 build_transpose_conv2d,
4450 TosaTensorGen.tgTransposeConv2D,
4451 TosaArgGen.agTransposeConv2D,
4452 ),
4453 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004454 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004455 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004456 "template": True,
4457 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004458 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08004459 "clamp": {
4460 "op": Op.CLAMP,
4461 "operands": (1, 0),
4462 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
4463 "types": TYPE_NARROW_INT_FP,
4464 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004465 "sigmoid": {
4466 "op": Op.SIGMOID,
4467 "operands": (1, 0),
4468 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
4469 "types": TYPE_FP,
4470 },
4471 "tanh": {
4472 "op": Op.TANH,
4473 "operands": (1, 0),
4474 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
4475 "types": TYPE_FP,
4476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004477 # Elementwise Binary Operators
4478 "add": {
4479 "op": Op.ADD,
4480 "operands": (2, 0),
4481 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4482 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004483 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4484 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004485 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004486 "arithmetic_right_shift": {
4487 "op": Op.ARITHMETIC_RIGHT_SHIFT,
4488 "operands": (2, 0),
4489 "build_fcn": (
4490 build_arithmetic_right_shift,
4491 TosaTensorGen.tgBroadcastFuzz,
4492 TosaArgGen.agArithmeticRightShift,
4493 ),
4494 "types": TYPE_INT,
4495 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004496 "bitwise_and": {
4497 "op": Op.BITWISE_AND,
4498 "operands": (2, 0),
4499 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4500 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004501 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4502 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004504 "bitwise_or": {
4505 "op": Op.BITWISE_OR,
4506 "operands": (2, 0),
4507 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4508 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004509 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4510 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004511 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004512 "bitwise_xor": {
4513 "op": Op.BITWISE_XOR,
4514 "operands": (2, 0),
4515 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4516 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004517 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4518 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004519 },
Matthew Haddon459443c2021-08-23 16:43:13 +01004520 "intdiv": {
4521 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004522 "operands": (2, 0),
4523 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4524 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004525 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4526 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004527 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004528 "logical_and": {
4529 "op": Op.LOGICAL_AND,
4530 "operands": (2, 0),
4531 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4532 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004533 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4534 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004536 "logical_left_shift": {
4537 "op": Op.LOGICAL_LEFT_SHIFT,
4538 "operands": (2, 0),
4539 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4540 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004541 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4542 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004543 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004544 "logical_right_shift": {
4545 "op": Op.LOGICAL_RIGHT_SHIFT,
4546 "operands": (2, 0),
4547 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4548 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004549 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4550 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004551 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004552 "logical_or": {
4553 "op": Op.LOGICAL_OR,
4554 "operands": (2, 0),
4555 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4556 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004557 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4558 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004559 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004560 "logical_xor": {
4561 "op": Op.LOGICAL_XOR,
4562 "operands": (2, 0),
4563 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4564 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004565 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4566 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004567 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004568 "maximum": {
4569 "op": Op.MAXIMUM,
4570 "operands": (2, 0),
4571 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4572 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004573 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4574 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004575 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004576 "minimum": {
4577 "op": Op.MINIMUM,
4578 "operands": (2, 0),
4579 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4580 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004581 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4582 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004583 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004584 "mul": {
4585 "op": Op.MUL,
4586 "operands": (2, 0),
4587 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
4588 "types": TYPE_INT_FP,
4589 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004590 "pow": {
4591 "op": Op.POW,
4592 "operands": (2, 0),
4593 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
4594 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004595 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4596 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004598 "sub": {
4599 "op": Op.SUB,
4600 "operands": (2, 0),
4601 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4602 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004603 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4604 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004605 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004606 "table": {
4607 "op": Op.TABLE,
4608 # Use the automatic generation functions to create the input array
4609 # but create the table tensor in the build function, as it may be
4610 # a different type from the input
4611 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00004612 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004613 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08004614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004615 # Elementwise Unary operators
4616 "abs": {
4617 "op": Op.ABS,
4618 "operands": (1, 0),
4619 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4620 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004621 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4622 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004623 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004624 "bitwise_not": {
4625 "op": Op.BITWISE_NOT,
4626 "operands": (1, 0),
4627 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4628 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004629 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4630 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004631 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004632 "ceil": {
4633 "op": Op.CEIL,
4634 "operands": (1, 0),
4635 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4636 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004637 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4638 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004639 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004640 "clz": {
4641 "op": Op.CLZ,
4642 "operands": (1, 0),
4643 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4644 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004645 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4646 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004647 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004648 "exp": {
4649 "op": Op.EXP,
4650 "operands": (1, 0),
4651 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4652 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004653 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4654 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004656 "floor": {
4657 "op": Op.FLOOR,
4658 "operands": (1, 0),
4659 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4660 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004661 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4662 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004663 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004664 "log": {
4665 "op": Op.LOG,
4666 "operands": (1, 0),
4667 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4668 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004669 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4670 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004671 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004672 "logical_not": {
4673 "op": Op.LOGICAL_NOT,
4674 "operands": (1, 0),
4675 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4676 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004677 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4678 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004679 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004680 "negate": {
4681 "op": Op.NEGATE,
4682 "operands": (1, 0),
4683 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4684 "qgen": TosaQuantGen.qgUnary,
4685 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004686 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4687 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4688 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004690 "reciprocal": {
4691 "op": Op.RECIPROCAL,
4692 "operands": (1, 0),
4693 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4694 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004695 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4696 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004698 "rsqrt": {
4699 "op": Op.RSQRT,
4700 "operands": (1, 0),
4701 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4702 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004703 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4704 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004706 # Elementwise Ternary operators
4707 "select": {
4708 "op": Op.SELECT,
4709 "operands": (3, 0),
4710 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
4711 "types": TYPE_FIB,
4712 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004713 # Comparison operators
4714 "equal": {
4715 "op": Op.EQUAL,
4716 "operands": (2, 0),
4717 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4718 "types": TYPE_FI32,
4719 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004720 "greater_equal": {
4721 "op": Op.GREATER_EQUAL,
4722 "operands": (2, 0),
4723 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4724 "types": TYPE_FI32,
4725 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004726 "greater": {
4727 "op": Op.GREATER,
4728 "operands": (2, 0),
4729 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4730 "types": TYPE_FI32,
4731 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004732 # Reduction operators
4733 "reduce_all": {
4734 "op": Op.REDUCE_ALL,
4735 "operands": (1, 0),
4736 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4737 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004738 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4739 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4740 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004741 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004742 "reduce_any": {
4743 "op": Op.REDUCE_ANY,
4744 "operands": (1, 0),
4745 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4746 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004747 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4748 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4749 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004750 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004751 "reduce_max": {
4752 "op": Op.REDUCE_MAX,
4753 "operands": (1, 0),
4754 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4755 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004756 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4757 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4758 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004759 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004760 "reduce_min": {
4761 "op": Op.REDUCE_MAX,
4762 "operands": (1, 0),
4763 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4764 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004765 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4766 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4767 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004768 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004769 "reduce_product": {
4770 "op": Op.REDUCE_PRODUCT,
4771 "operands": (1, 0),
4772 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4773 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004774 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4775 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4776 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004777 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004778 "reduce_sum": {
4779 "op": Op.REDUCE_SUM,
4780 "operands": (1, 0),
4781 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4782 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004783 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4784 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4785 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004786 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004787 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004788 "concat": {
4789 "op": Op.CONCAT,
4790 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01004791 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004792 "types": TYPE_FIB,
4793 },
4794 "pad": {
4795 "op": Op.PAD,
4796 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004797 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004798 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4799 "qgen": TosaQuantGen.qgPad,
4800 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004801 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
4802 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004803 },
4804 "reshape": {
4805 "op": Op.RESHAPE,
4806 "operands": (1, 0),
4807 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4808 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004809 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4810 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004811 },
4812 "reverse": {
4813 "op": Op.REVERSE,
4814 "operands": (1, 0),
4815 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4816 "types": TYPE_FIB,
4817 },
4818 "slice": {
4819 "op": Op.SLICE,
4820 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01004821 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004822 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4823 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004824 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
4825 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
4826 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004827 },
4828 "tile": {
4829 "op": Op.TILE,
4830 "operands": (1, 0),
4831 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4832 "types": TYPE_FIB,
4833 },
4834 "transpose": {
4835 "op": Op.TRANSPOSE,
4836 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004837 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004838 "build_fcn": (
4839 build_transpose,
4840 TosaTensorGen.tgBasic,
4841 TosaArgGen.agTranspose,
4842 ),
4843 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01004844 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
4845 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004846 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004847 # Data nodes
4848 "const": {
4849 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004850 "operands": (0, 1),
4851 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004852 "types": TYPE_FIB,
4853 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004854 "identity": {
4855 "op": Op.IDENTITY,
4856 "operands": (1, 0),
4857 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4858 "types": TYPE_FIB,
4859 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004860 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004861 "gather": {
4862 "op": Op.GATHER,
4863 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4864 "operands": (1, 0),
4865 "rank": (3, 3),
4866 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4867 "types": TYPE_INT_FP,
4868 },
4869 "scatter": {
4870 "op": Op.SCATTER,
4871 # Only specify 'values_in' tensor here.
4872 #'indices' and 'input' are generated in op building stage
4873 "operands": (2, 0),
4874 "rank": (3, 3),
4875 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4876 "types": TYPE_INT_FP,
4877 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004878 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004879 "resize": {
4880 "op": Op.RESIZE,
4881 "operands": (1, 0),
4882 "rank": (4, 4),
4883 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4884 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004885 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4886 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4887 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004888 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004889 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4890 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004891 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004892 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004893 "cast": {
4894 "op": Op.CAST,
4895 "operands": (1, 0),
4896 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4897 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4898 },
4899 "rescale": {
4900 "op": Op.RESCALE,
4901 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01004902 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004903 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004904 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01004905 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
4906 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4907 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004908 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004909 # Custom
4910 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004911 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004912 # Two varients of cond_if, one that generates one of two constant tensors (no
4913 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4914 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004915 "cond_if_const": {
4916 "op": Op.COND_IF,
4917 "operands": (0, 2),
4918 "build_fcn": (
4919 build_cond_if_const,
4920 TosaTensorGen.tgBasic,
4921 TosaArgGen.agCondIf,
4922 ),
4923 "types": [DType.BOOL],
4924 },
4925 "cond_if_binary": {
4926 "op": Op.COND_IF,
4927 "operands": (2, 0),
4928 "build_fcn": (
4929 build_cond_if_binary,
4930 TosaTensorGen.tgBasic,
4931 TosaArgGen.agCondIf,
4932 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004933 "types": TYPE_INT_FP,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004934 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004935 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004936 "while_loop": {
4937 "op": Op.WHILE_LOOP,
4938 "operands": (0, 1),
4939 "build_fcn": (
4940 build_while_loop,
4941 TosaTensorGen.tgBasic,
4942 TosaArgGen.agWhileLoop,
4943 ),
4944 "types": [DType.INT32],
4945 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004946 }
4947
Kevin Cheng550ccc52021-03-03 11:21:43 -08004948
Eric Kunzee5e26762020-10-13 16:11:07 -07004949class OutputShaper:
4950 # Methods in this class compute the expected output shape and datatype
4951 # for common classes of operations
4952 def __init__(self):
4953 pass
4954
4955 # These methods return arguments that can be used for
4956 # creating a new output tensor
4957 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004958 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4959 if error_name != ErrorIf.RankMismatch:
4960 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004961 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004962
4963 shape = []
4964 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004965 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004966 shape.append(b.shape[i])
4967 else:
4968 shape.append(a.shape[i])
4969
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004970 if error_name == ErrorIf.WrongOutputType:
4971 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4972 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4973 outputDType = rng.choice(wrong_dtypes)
4974 else:
4975 outputDType = a.dtype
4976
4977 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004978
4979 @staticmethod
4980 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004981 assert len(a.shape) == len(b.shape)
4982 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004983
4984 shape = []
4985 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004986 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004987 shape.append(a.shape[i])
4988
Kevin Cheng550ccc52021-03-03 11:21:43 -08004989 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004990
4991 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004992 def unaryOp(ser, rng, a, error_name=None):
4993 if error_name == ErrorIf.WrongOutputType:
4994 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4995 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4996 outputDType = rng.choice(wrong_dtypes)
4997 else:
4998 outputDType = a.dtype
4999
5000 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005001
5002 @staticmethod
5003 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005004 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
5005 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005006
5007 shape = []
5008 for i in range(len(a.shape)):
5009 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
5010
Kevin Cheng550ccc52021-03-03 11:21:43 -08005011 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005012
5013 @staticmethod
5014 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005015 assert len(a.shape) == len(b.shape)
5016 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005017
5018 # Do broadcast
5019 shape = []
5020 for i in range(len(a.shape)):
5021 if a.shape[i] == 1:
5022 shape.append(b.shape[i])
5023 else:
5024 shape.append(a.shape[i])
5025
5026 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08005027 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07005028
5029 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005030 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005031 shape = a.shape.copy()
Matthew Haddond6ce7252021-09-29 15:35:44 +01005032 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
5033 shape[axis] = 1
5034 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5035 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005036
Matthew Haddond6ce7252021-09-29 15:35:44 +01005037 if error_name == ErrorIf.WrongOutputType:
5038 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5039 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5040 outputDType = rng.choice(wrong_dtypes)
5041 else:
5042 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005043
Matthew Haddond6ce7252021-09-29 15:35:44 +01005044 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005045
5046 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005047 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005048 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005049
5050 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5051 del shape[axis]
5052
5053 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5054 remove = rng.choice([True, False])
5055 if remove and len(shape) > 1:
5056 del shape[0]
5057 else:
5058 shape.append(1)
5059 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5060 for i in range(len(shape)):
5061 shape[i] = shape[i] + rng.integers(1, 10)
5062
5063 if error_name == ErrorIf.WrongOutputType:
5064 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
5065 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5066 outputDType = rng.choice(wrong_dtypes)
5067 else:
5068 outputDType = DType.INT32
5069
5070 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005071
5072 @staticmethod
5073 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
5074
5075 # IFM: NHWC
5076 # Filter: OHWI
5077 # OFM: NHWC
5078
5079 if len(padding) == 2:
5080 # Expand padding to 4 parameters in the case of transpose_conv2d
5081 # From H,W to T,B,L,R
5082 padding = [padding[0], padding[0], padding[1], padding[1]]
5083
Kevin Cheng550ccc52021-03-03 11:21:43 -08005084 h = (
5085 ifm.shape[1]
5086 - filter.shape[1]
5087 - (filter.shape[1] - 1) * (dilations[0] - 1)
5088 + padding[0]
5089 + padding[1]
5090 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005091
Kevin Cheng550ccc52021-03-03 11:21:43 -08005092 w = (
5093 ifm.shape[2]
5094 - filter.shape[2]
5095 - (filter.shape[2] - 1) * (dilations[1] - 1)
5096 + padding[2]
5097 + padding[3]
5098 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005099
Eric Kunzee5e26762020-10-13 16:11:07 -07005100 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
5101
Kevin Cheng3a478572021-01-22 17:21:02 -08005102 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005103 out_dtype = DType.INT32
5104 elif ifm.dtype == DType.INT16:
5105 out_dtype = DType.INT48
5106 elif ifm.dtype == DType.FLOAT:
5107 out_dtype = DType.FLOAT
5108 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005109 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005110
Kevin Cheng550ccc52021-03-03 11:21:43 -08005111 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005112
5113 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07005114 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
5115
5116 # IFM: NDHWC
5117 # Filter: ODHWI
5118 # OFM: NDHWC
5119
5120 d = (
5121 ifm.shape[1]
5122 - filter.shape[1]
5123 - (filter.shape[1] - 1) * (dilations[0] - 1)
5124 + padding[0]
5125 + padding[1]
5126 ) // strides[0] + 1
5127
5128 h = (
5129 ifm.shape[2]
5130 - filter.shape[2]
5131 - (filter.shape[2] - 1) * (dilations[1] - 1)
5132 + padding[2]
5133 + padding[3]
5134 ) // strides[1] + 1
5135
5136 w = (
5137 ifm.shape[3]
5138 - filter.shape[3]
5139 - (filter.shape[3] - 1) * (dilations[2] - 1)
5140 + padding[4]
5141 + padding[5]
5142 ) // strides[2] + 1
5143
5144 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
5145
5146 if ifm.dtype == DType.INT8:
5147 out_dtype = DType.INT32
5148 elif ifm.dtype == DType.INT16:
5149 out_dtype = DType.INT48
5150 elif ifm.dtype == DType.FLOAT:
5151 out_dtype = DType.FLOAT
5152 else:
5153 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
5154
5155 return ser.addOutput(ofm_shape, out_dtype)
5156
5157 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07005158 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
5159 # IFM: NHWC
5160 # Filter: HWCM
5161 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08005162 h = (
5163 ifm.shape[1]
5164 - filter.shape[0]
5165 - (filter.shape[0] - 1) * (dilations[0] - 1)
5166 + padding[0]
5167 + padding[1]
5168 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005169
Kevin Cheng550ccc52021-03-03 11:21:43 -08005170 w = (
5171 ifm.shape[2]
5172 - filter.shape[1]
5173 - (filter.shape[1] - 1) * (dilations[1] - 1)
5174 + padding[2]
5175 + padding[3]
5176 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005177
Eric Kunzee5e26762020-10-13 16:11:07 -07005178 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
5179
Kevin Cheng3a478572021-01-22 17:21:02 -08005180 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005181 out_dtype = DType.INT32
5182 elif ifm.dtype == DType.INT16:
5183 out_dtype = DType.INT48
5184 elif ifm.dtype == DType.FLOAT:
5185 out_dtype = DType.FLOAT
5186 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005187 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005188
Kevin Cheng550ccc52021-03-03 11:21:43 -08005189 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005190
5191 @staticmethod
    def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
        """Compute the output shape/dtype for a 2D pooling op and register the
        output tensor with the serializer.

        ifm is laid out NHWC; kernel/stride are [H, W]; pad is
        [top, bottom, left, right]. When error_name is set, the shape or
        dtype is deliberately corrupted to build a negative (ERROR_IF) test.
        """
        # input: NHWC
        if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
            # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
            h = 1
            w = 1
        else:
            h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
            w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

        if error_name == ErrorIf.PoolingOutputShapeMismatch:
            # Grow H and W by random amounts so the declared output shape
            # no longer matches the computed one.
            choices = [1, 2, 3, 4, 5]
            h = h + rng.choice(choices)
            w = w + rng.choice(choices)

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

        if error_name == ErrorIf.WrongOutputType:
            # Pick any dtype other than the input's for the mismatch test.
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = ifm.dtype

        return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005217
5218 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005219 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005220 # input: N, IC
5221 # filter: OC, IC
5222 # output: N, OC
5223
5224 output_shape = [input.shape[0], filter.shape[0]]
5225
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005226 if error_name == ErrorIf.WrongOutputType:
5227 if input.dtype == DType.INT8:
5228 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
5229 elif input.dtype == DType.INT16:
5230 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
5231 elif input.dtype == DType.FLOAT:
5232 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
5233 out_dtype = rng.choice(a=incorrect_types)
5234 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005235 out_dtype = DType.INT32
5236 elif input.dtype == DType.INT16:
5237 out_dtype = DType.INT48
5238 elif input.dtype == DType.FLOAT:
5239 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005240 elif error_name == ErrorIf.WrongInputType:
5241 # Pick some potentially correct output dtype if input type is incorrect
5242 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005243 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005244 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005245
Kevin Cheng550ccc52021-03-03 11:21:43 -08005246 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005247
5248 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005249 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07005250 # a: N, H, C
5251 # b: N, C, W
5252 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07005253
Kevin Cheng2d60f002021-06-09 14:18:32 -07005254 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005255
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005256 if error_name == ErrorIf.WrongOutputType:
5257 if a.dtype == DType.INT8:
5258 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
5259 elif a.dtype == DType.INT16:
5260 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
5261 elif a.dtype == DType.FLOAT:
5262 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
5263 out_dtype = rng.choice(a=incorrect_types)
5264 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005265 out_dtype = DType.INT32
5266 elif a.dtype == DType.INT16:
5267 out_dtype = DType.INT48
5268 elif a.dtype == DType.FLOAT:
5269 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005270 elif error_name == ErrorIf.WrongInputType:
5271 # Pick some potentially correct output dtype if input type is incorrect
5272 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005273 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005274 raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005275
Kevin Cheng550ccc52021-03-03 11:21:43 -08005276 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005277
5278 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01005279 def concatOp(ser, axis, *a):
5280 input1 = a[0]
5281 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07005282
Matthew Haddon818ab902021-07-27 09:12:49 +01005283 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07005284
Matthew Haddon818ab902021-07-27 09:12:49 +01005285 output_shape[axis] = input1.shape[axis]
5286
5287 for tensor in remaining_inputs:
5288 output_shape[axis] += tensor.shape[axis]
5289
5290 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005291
5292 @staticmethod
    def padOp(ser, rng, a, padding, error_name=None):
        """Compute the PAD output shape/dtype and register the output tensor.

        padding is a per-dimension sequence of [before, after] amounts added
        to the input shape. When error_name is set, the shape or dtype is
        deliberately corrupted to build a negative (ERROR_IF) test.
        """

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

        # Fix negative output shape if error_if test causes it
        if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
            # Clamp every dimension to at least 1 so the tensor stays legal.
            output_shape = [i if i >= 1 else 1 for i in output_shape]

        if error_name == ErrorIf.WrongOutputType:
            # Pick any dtype other than the input's for the mismatch test.
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005312
5313 @staticmethod
    def reshapeOp(ser, rng, a, shape, error_name=None):
        """Compute the RESHAPE output shape/dtype and register the output.

        A single -1 entry in `shape` is resolved from the input's element
        count. When error_name is set, the shape or dtype is deliberately
        corrupted to build a negative (ERROR_IF) test.
        """
        output_shape = shape.copy()

        # Total element count of the input tensor.
        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            # Grow every dimension so the element counts can no longer match.
            for i in range(len(output_shape)):
                output_shape[i] = output_shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            # Pick any dtype other than the input's for the mismatch test.
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005344
5345 @staticmethod
    def sliceOp(ser, rng, a, start, size, error_name=None):
        """Compute the SLICE output shape/dtype and register the output.

        The output shape is normally `size` (`start` is not used here).
        When error_name is set, the shape or dtype is deliberately corrupted
        to build a negative (ERROR_IF) test.
        """

        if error_name == ErrorIf.WrongOutputType:
            # Pick any dtype other than the input's for the mismatch test.
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        if error_name == ErrorIf.SizeOutputShapeMismatch:
            output_shape = size.copy()
            for index in range(len(output_shape)):
                if output_shape[index] <= 2:
                    # Only grow small dimensions so the result stays > 0.
                    output_shape[index] = output_shape[index] + rng.choice([1, 2])
                else:
                    output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2])
        else:
            output_shape = size.copy()

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005366
5367 @staticmethod
5368 def tileOp(ser, a, multiples):
5369
5370 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005371 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005372
5373 for i in range(len(output_shape)):
5374 output_shape[i] = a.shape[i] * multiples[i]
5375
Kevin Cheng550ccc52021-03-03 11:21:43 -08005376 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005377
5378 @staticmethod
    def transposeOp(ser, rng, a, perms, error_name=None):
        """Compute the TRANSPOSE output shape/dtype and register the output.

        Output dimension i is input dimension perms[i]. When error_name is
        set, the shape or dtype is deliberately corrupted to build a
        negative (ERROR_IF) test.
        """
        output_shape = a.shape.copy()

        assert len(perms) == len(output_shape)

        if error_name == ErrorIf.IndexOutsideBounds:
            # Use dim 0 everywhere instead of indexing with perms
            # (presumably perms holds out-of-bounds indices in this test —
            # indexing with them would raise here; verify against caller).
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[0]
        else:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[perms[i]]

        if error_name == ErrorIf.WrongOutputType:
            # Pick any dtype other than the input's for the mismatch test.
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005399
5400 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08005401 def gatherOp(ser, values, indices):
5402 assert len(values.shape) == 3
5403 assert len(indices.shape) == 2
5404 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005405
Kevin Cheng77d0f762020-11-24 10:26:32 -08005406 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5407
Kevin Cheng550ccc52021-03-03 11:21:43 -08005408 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005409
5410 @staticmethod
5411 def scatterOp(ser, values_in, indices, input):
5412 assert len(values_in.shape) == 3
5413 assert len(indices.shape) == 2
5414 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005415 assert values_in.shape[0] == indices.shape[0] # N
5416 assert input.shape[1] == indices.shape[1] # W
5417 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005418
5419 output_shape = values_in.shape
5420
Kevin Cheng550ccc52021-03-03 11:21:43 -08005421 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005422
5423 @staticmethod
Kevin Chengfe392ce2021-10-18 21:51:55 +00005424 def tableOp(ser, input):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005425 # Same shape as the input, but dtype dependent on table dtype
Kevin Chengfe392ce2021-10-18 21:51:55 +00005426 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
5427 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005428 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005429
5430 @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name = None
    ):
        """Compute the RESIZE output shape and register the output tensor.

        input is laid out NHWC and output_dims carries the requested [H, W].
        Only input, output_dims, output_dtype and error_name are used here;
        the remaining resize parameters are accepted for a uniform call
        signature but not referenced. When error_name is set, the batch,
        channel, or rank of the output is deliberately corrupted to build a
        negative (ERROR_IF) test.
        """
        if error_name == ErrorIf.WrongRank:
            # Build a deliberately wrong-rank/wrong-sized shape.
            output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
        else:
            if error_name == ErrorIf.BatchMismatch:
                output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
            elif error_name == ErrorIf.ChannelMismatch:
                output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
            else:
                # Normal case: batch and channels carried over from the input.
                output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005457
5458 @staticmethod
5459 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005460 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005461
5462 @staticmethod
5463 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08005464 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07005465 out_dtype = DType.INT32
5466 elif ifm.dtype == DType.INT16:
5467 out_dtype = DType.INT48
5468 elif ifm.dtype == DType.FLOAT:
5469 out_dtype = DType.FLOAT
5470 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005471 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07005472
Kevin Cheng550ccc52021-03-03 11:21:43 -08005473 return ser.addOutput(output_shape, out_dtype)