blob: 22886d6482668ee4ae811369b0dd6cc90dd04f1d [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
Eric Kunzee5e26762020-10-13 16:11:07 -070017import numpy as np
18import argparse
19import sys
20import re
21import os
22import subprocess
23import shlex
24import json
25import glob
26import math
27import queue
28import threading
29import traceback
30import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010031import itertools
Matthew Haddon630c17c2021-10-14 15:05:41 +010032from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
Les Bell0e027d42021-11-09 14:42:14 +000048from tosa.DType import DType
49from tosa.Op import Op
50from tosa.ResizeMode import ResizeMode
Eric Kunzee5e26762020-10-13 16:11:07 -070051
Matthew Haddon630c17c2021-10-14 15:05:41 +010052
def valueToName(item, value):
    """Return the name of the first attribute of *item* whose value equals *value*.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes, which,
    sadly, are not subclasses of Enum or IntEnum.

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)   # returns 'INT8'
        name = valueToName(DType, 4)            # returns 'INT8'

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Returns:
        The name of the first attribute found with a matching value.

    Raises:
        ValueError: if the value is not found
    """
    matches = (name for name in dir(item) if getattr(item, name) == value)
    try:
        return next(matches)
    except StopIteration:
        raise ValueError(f'value ({value}) not found') from None
80
def allDTypes(*, excludes=None):
    """Return the set of all DType values, optionally excluding some values.

    DType is a flatc-generated class rather than an Enum/IntEnum, so its
    values are discovered by scanning the class attributes with dir() and
    filtering out callables and dunder names.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omitted = set(excludes) if excludes else set()
    dtypes = set()
    for name in dir(DType):
        if name.startswith('__'):
            continue
        val = getattr(DType, name)
        if not callable(val) and val not in omitted:
            dtypes.add(val)
    return dtypes
99
def usableDTypes(*, excludes=None):
    """Return the set of DType values usable by the serializer lib.

    DType.UNKNOWN and DType.UINT8 are always excluded (the serialization
    library does not support them), in addition to any caller-supplied
    excludes.  Use allDTypes() if 'UNKNOWN' or 'UINT8' are wanted.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    always_omit = {DType.UNKNOWN, DType.UINT8}
    if excludes:
        always_omit.update(excludes)
    return allDTypes(excludes=always_omit)
116
def product(shape):
    """Return the product of all dimensions in *shape*.

    This is the element count of a tensor with that shape.  An empty
    shape (rank 0) yields 1, matching the previous manual accumulation.

    Args:
        shape: iterable of numeric dimension sizes

    Returns:
        The product of the entries of *shape*.
    """
    # math.prod is the idiomatic, C-accelerated stdlib replacement for the
    # hand-rolled multiply loop; identical semantics including the empty case.
    return math.prod(shape)
122
Les Bell0e027d42021-11-09 14:42:14 +0000123
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Selected with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a random zero point suitable for dtype.

        INT8 draws from [-128, 127] and UINT8 from [0, 255].  For the
        zero-point ERROR_IF cases a non-zero value is forced even for
        dtypes whose zero point must be 0; every other dtype returns 0.
        """

        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
            zero_point = testGen.randInt(-128, 128)
            # The value must be non-zero to actually trigger the ERROR_IF
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo as (input_zp, output_zp).

        For the zero-point ERROR_IF cases, only the targeted zero point is
        made invalid (non-zero); the other one stays valid.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo as (input_zp, weights_zp).

        dtype_or_dtypeList may be a single dtype or a list of
        [input, weights, accumulator] dtypes.  For the zero-point ERROR_IF
        cases, only the targeted zero point is made invalid.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo as (a_zp, b_zp).

        For InputZeroPointNotZero both operand zero points are made invalid.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo with a single input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale to a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  Provide a
        floating-point scaling factor and the scale32 parameter (True for a
        32-bit multiplier, False for 16-bit) to compute the fixed-point
        multiplier and shift.
        """

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        # frexp decomposes scaleFp into mantissa m in [0.5, 1) and exponent
        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can push the mantissa to exactly 1.0 - renormalize
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
248
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated
    separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Generate one random shape of the given rank for every operand."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # For RankMismatch tests change the rank used for the remaining
            # operands (the second operand, i == 1, keeps the correct rank)
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Generate an NHWC shape (rank 4, unless testing WrongRank)."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        # (MaxDimExceeded deliberately needs the oversized dimensions kept)
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Generate [values_in, input] shapes for gather/scatter style ops."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Generate shapes where one randomly chosen input is broadcast."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    # Invalid: dimension differs but is not 1
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks)))
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    # Valid broadcast: set the chosen dimension to 1
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Generate [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Generate [input, filter, bias] shapes for FULLY_CONNECTED (rank 2)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Generate [a, b] shapes for MATMUL (rank 3, batched)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Generate a list of shapes (with extras) for CONCAT inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Invalidate the rank of every input except the first by
                # randomly removing or appending a dimension
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rework CONCAT shapes so const inputs split the first shape along axis."""
        if error_name in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
619
620
Eric Kunzee5e26762020-10-13 16:11:07 -0700621class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800622 """Argument generators create exhaustive or random lists of attributes for operators that take
623 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
624 tuples where the descriptive_name is appended to the test name and the arglist is expanded
625 as arguments to the operator build function."""
626
Eric Kunzee5e26762020-10-13 16:11:07 -0700627 def __init__(self):
628 pass
629
630 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100631 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800632 """A trivial argument generator for operators that don't take any
633 non-tensor arguments"""
634 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700635
636 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100637 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800638 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700640 shape = shapeList[0]
641
Matthew Haddond6ce7252021-09-29 15:35:44 +0100642 if error_name == ErrorIf.AxisSmallerZero:
643 small_axis = testGen.rng.integers(-5, 0)
644 axes.append(("axis{}".format(small_axis), [small_axis]))
645 elif error_name == ErrorIf.AxisLargerRank:
646 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
647 axes.append(("axis{}".format(large_axis), [large_axis]))
648 else:
649 for a in range(0, len(shape)):
650 axes.append(("axis{}".format(a), [a]))
651
Eric Kunzee5e26762020-10-13 16:11:07 -0700652 return axes
653
654 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100655 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700656 arg_list = []
657
658 ifm_shape = shapeList[0]
659 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100660 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
661 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700662
Les Bell7aa69f42021-09-20 10:44:07 +0100663 # Check the rank
664 rank = 5 if opName.startswith("conv3d") else 4
Les Bell0e027d42021-11-09 14:42:14 +0000665 if error_name != ErrorIf.WrongRank:
666 assert len(ifm_shape) == rank
667 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700668
Les Bell7aa69f42021-09-20 10:44:07 +0100669 # kernel rank omits batch and channels
670 k_rank = rank - 2
Les Bell0e027d42021-11-09 14:42:14 +0000671 assert len(k) == k_rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700672
Les Bell7aa69f42021-09-20 10:44:07 +0100673 # Generate comprehensive argument lists
Les Bell0e027d42021-11-09 14:42:14 +0000674 # - except for named errors, which use specific invalid value(s)
675 if error_name == ErrorIf.PadSmallerZero:
676 p_vals = [testGen.rng.choice(range(-5, 0))]
677 else:
678 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100679 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
Les Bell0e027d42021-11-09 14:42:14 +0000680 if error_name == ErrorIf.StrideSmallerOne:
681 # Can't use stride=0, as it is used to derive output shape, as a divisor
682 s_vals = [testGen.rng.choice(range(-5, 0))]
683 else:
684 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100685 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
Les Bell0e027d42021-11-09 14:42:14 +0000686 if error_name == ErrorIf.DilationSmallerOne:
687 d_vals = [testGen.rng.choice(range(-5, 1))]
688 else:
689 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100690 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700691
Les Bell0e027d42021-11-09 14:42:14 +0000692 if not error_name:
693 # add some oversize argument values
694 if max(ifm_shape) < 64:
695 bigPadding = 9
696 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
697 bigStride = 8
698 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
699 bigDilation = 7
700 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100701
Les Bell0e027d42021-11-09 14:42:14 +0000702 # There are too many parameter combinations, so generate them sparsely,
703 # very sparse for negative tests
704 sparsity_factor = 2 if error_name else 100
705 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
706 # If there are only a small number of tests, just select them all
Les Bell7aa69f42021-09-20 10:44:07 +0100707 if sparsity < 13:
708 sparsity = 1
Les Bell0e027d42021-11-09 14:42:14 +0000709 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
Les Bell7aa69f42021-09-20 10:44:07 +0100710 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
711 sparsity += 1
Les Bell0e027d42021-11-09 14:42:14 +0000712
Les Bellf414b3c2021-09-06 11:29:46 +0100713 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100714 for s in sorted(list(strides)):
715 for p in sorted(list(paddings)):
716 for d in sorted(list(dilations)):
717 if (n % sparsity == 0
718 # padding must not exceed the kernel size ?
719 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
720 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
721 # the padded shape must exceed the kernel size
722 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
723 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
724 # the padded shape must exceed the dilation
725 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
726 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
727 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100728 arg_list.append(
729 (
730 "st{}_pad{}_dilat{}".format(
731 "".join([str(x) for x in s]),
732 "".join([str(x) for x in p]),
733 "".join([str(x) for x in d]),
734 ),
735 [s, p, d],
736 )
737 )
738 n += 1
739
Kevin Cheng1533b852021-09-01 12:51:58 -0700740 return arg_list
741
742 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100743 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700744 arg_list = []
745
746 ifm_shape = shapeList[0]
747 filter_shape = shapeList[1]
748
749 # Must be rank 4
Les Bell0e027d42021-11-09 14:42:14 +0000750 if error_name != ErrorIf.WrongRank:
751 assert len(ifm_shape) == 4
752 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700753
Les Bell7aa69f42021-09-20 10:44:07 +0100754 # Generate comprehensive argument lists
Les Bell0e027d42021-11-09 14:42:14 +0000755 # - except for named errors, which use specific invalid value(s)
756 if error_name == ErrorIf.PadSmallerZero:
757 p_vals = [testGen.rng.choice(range(-5, 0))]
758 else:
759 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100760 paddings = {x for x in itertools.product(*([p_vals] * 2))}
Les Bell0e027d42021-11-09 14:42:14 +0000761 if error_name == ErrorIf.StrideSmallerOne:
762 # Can't use stride=0, as it is used to derive output shape, as a divisor
763 s_vals = [testGen.rng.choice(range(-5, 0))]
764 else:
765 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100766 strides = {x for x in itertools.product(*([s_vals] * 2))}
Les Bell0e027d42021-11-09 14:42:14 +0000767 if error_name == ErrorIf.DilationSmallerOne:
768 d_vals = [testGen.rng.choice(range(-5, 1))]
769 else:
770 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100771 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700772
Les Bell0e027d42021-11-09 14:42:14 +0000773 if not error_name:
774 # add some oversize argument values
775 if max(ifm_shape) < 64:
776 bigPadding = 9
777 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
778 bigStride = 8
779 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
780 bigDilation = 7
781 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700782
Les Bell0e027d42021-11-09 14:42:14 +0000783 # There are too many parameter combinations, so generate them sparsely,
784 # very sparse for negative tests
785 sparsity_factor = 2 if error_name else 100
786 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
787 # If there are only a small number of tests, just select them all
Les Bell7aa69f42021-09-20 10:44:07 +0100788 if sparsity < 13:
789 sparsity = 1
Les Bell0e027d42021-11-09 14:42:14 +0000790 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
Les Bell7aa69f42021-09-20 10:44:07 +0100791 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
792 sparsity += 1
Les Bell0e027d42021-11-09 14:42:14 +0000793
Les Bell7aa69f42021-09-20 10:44:07 +0100794 n = 0
795 for s in sorted(list(strides)):
796 for p in sorted(list(paddings)):
797 for d in sorted(list(dilations)):
798 if n % sparsity == 0:
799 # Determine the output shape
800 oh = (
801 ifm_shape[1]
802 - filter_shape[1]
803 - (filter_shape[1] - 1) * (d[0] - 1)
804 + 2 * p[0]
805 ) // s[0] + 1
806 ow = (
807 ifm_shape[2]
808 - filter_shape[2]
809 - (filter_shape[2] - 1) * (d[1] - 1)
810 + 2 * p[1]
811 ) // s[1] + 1
812 os = [ifm_shape[0], oh, ow, filter_shape[0]]
813 arg_list.append(
814 (
815 "st{}_pad{}_dilat{}_os{}".format(
816 "".join([str(x) for x in s]),
817 "".join([str(x) for x in p]),
818 "".join([str(x) for x in d]),
819 "x".join([str(x) for x in os]),
820 ),
821 [s, p, d, os],
822 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800823 )
Les Bell7aa69f42021-09-20 10:44:07 +0100824 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700825
826 return arg_list
827
828 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100829 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700830 arg_list = []
831 rank = len(shapeList[0])
832
Les Bell7ffccce2021-07-28 15:37:02 +0100833 # Exhaustively test combinations of padding on each side of each dimension
834 # - the range of padding values is defined by pad_min and pad_max
835 # - for padding >9, the name format needs to be more distinctive
836 pad_min, pad_max = 0, 1
837 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100838 if error_name == ErrorIf.PadSmallerZero:
839 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100840 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
841 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700842
Kevin Chengfe392ce2021-10-18 21:51:55 +0000843 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
844 pad_const_int = testGen.getRandNumberDType(dtype)
845 pad_const_fp = 0
846 elif dtype == DType.FLOAT:
847 pad_const_int = 0
848 pad_const_fp = testGen.getRandNumberDType(dtype)
849 else:
850 return []
851
Les Bell7ffccce2021-07-28 15:37:02 +0100852 for paddings in shape_pad_values:
853 name = "pad"
854 for r in range(rank):
855 before, after = paddings[r]
856 name = f"{name}{before}{after}"
Kevin Chengfe392ce2021-10-18 21:51:55 +0000857 arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700858
859 return arg_list
860
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (name, [stride, padding, kernel]) argument combinations
        for pooling operator tests.

        Builds comprehensive stride/padding/kernel value sets up to the
        configured maxima (plus some oversize values), then samples them
        sparsely.  For the listed ERROR_IF names, values are rewritten to
        specific invalid ones by TosaErrorIfArgGen.eiPoolingErrorIf.
        """
        arg_list = []

        shape = shapeList[0]
        # Input must be rank 4 (NHWC) unless deliberately testing wrong rank
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        # padding: 4 values (top/bottom/left/right); stride and kernel: 2 (y, x)
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        # n counts every candidate combination; only every sparsity-th one
        # that also satisfies the validity constraints is emitted
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        # Negative test: substitute specifically-invalid values.
                        # eiPoolingErrorIf may return Nones, meaning "no
                        # invalid variant for this combination" - skip those.
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
928
929 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100930 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700931 arg_list = []
932
933 # Enumerate the output types here
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100934 if error_name == ErrorIf.WrongOutputType:
935 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
936 elif inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800937 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700938 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800939 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700940 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800941 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700942 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800943 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700944 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800945 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100946 elif error_name == ErrorIf.WrongInputType:
947 # Pick some potentially correct output type for incorrect input type
948 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700949 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800950 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700951
952 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800953 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700954
955 return arg_list
956
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate (name, [outDtype, scale32, double_round, per_channel])
        argument combinations for RESCALE tests.

        Iterates all output dtype / flag combinations, skipping those that
        are illegal per the spec (e.g. !scale32 with double_round), except
        where a named ERROR_IF test deliberately needs the illegal pairing.
        """
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            # zero-point negative tests only apply to 16/32-bit outputs here
            if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
                continue
            if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue
            if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
                continue

            for scale32 in [False, True]:
                # ScaleTrue/ScaleNotTrue negative tests pin the flag values
                if error_name == ErrorIf.ScaleTrue and scale32 == False:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and double_round == False:
                        continue
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
                            # Illegal condition. Must be scale32=False
                            continue
                        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
1004
Kevin Chengaee1fac2020-11-11 13:54:06 -08001005 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001006 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -08001007 arg_list = []
1008
1009 if dtype is DType.INT32:
1010 for p in range(testGen.args.num_rand_permutations):
1011
1012 shift = testGen.randInt(0, 32)
1013
Kevin Cheng550ccc52021-03-03 11:21:43 -08001014 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001015 else:
Matthew Haddon43e37192021-07-09 14:13:02 +01001016 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001017
1018 return arg_list
1019
1020 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001021 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -08001022 arg_list = []
1023
Kevin Cheng550ccc52021-03-03 11:21:43 -08001024 arg_list.append(("roundTrue", [True]))
1025 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001026
1027 return arg_list
1028
Eric Kunzee5e26762020-10-13 16:11:07 -07001029 # Helper function for reshape. Gets some factors of a larger number.
1030 @staticmethod
1031 def getFactors(val, start=1):
1032 factors = []
1033
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001034 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -07001035 if (val % i) == 0:
1036 factors.append(i)
1037
1038 return factors
1039
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random new-shape arguments for RESHAPE tests.

        Each entry is ("perm<p>_rank<r>", [newShape]) where newShape
        preserves the input's total element count; one dimension is
        occasionally replaced by -1 (to be inferred by the operator).
        Duplicate shapes are rejected, with an escape counter to avoid
        spinning forever on small search spaces.
        """
        arg_list = []

        origShape = shapeList[0]

        # Any generated shape must multiply out to this element count
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # target rank for the new shape; randInt upper bound appears
            # exclusive, i.e. ranks 1..6 - TODO confirm against randInt
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # the final dimension takes whatever element count is left
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
1095
Eric Kunzee5e26762020-10-13 16:11:07 -07001096 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001097 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001098 arg_list = []
1099
1100 ifm_shape = shapeList[0]
1101
Matthew Haddone807aae2021-10-11 18:12:58 +01001102
1103 if error_name == ErrorIf.IndexOutsideBounds:
1104 incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
1105 incorrect_small_index = range(-len(ifm_shape), 0)
1106 permutations = [p for p in itertools.permutations(incorrect_large_index)]
1107 permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
1108 elif error_name == ErrorIf.IndexUsedTwice:
1109 # Create list with a duplicated index
1110 perm_range = list(range(len(ifm_shape)))
1111 index_choice = testGen.rng.choice(range(len(perm_range)))
1112 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1113 permutations = [p for p in itertools.permutations(perm_range)]
1114
1115
1116 else:
1117 # Get all permutations
1118 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -07001119
Jeremy Johnsona6185572021-06-21 15:55:35 +01001120 # Limit to possible permutations from shape dimension or argument setting
1121 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001122
Jeremy Johnsona6185572021-06-21 15:55:35 +01001123 # Get random permutation generator that uses all permutations
1124 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001125
Jeremy Johnsona6185572021-06-21 15:55:35 +01001126 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -07001127 arg_list = [
1128 ("perm{}".format(p), [random_permutations[p].tolist()])
1129 for p in range(limit)
1130 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07001131 return arg_list
1132
1133 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001134 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001135 arg_list = []
1136
1137 ifm_shape = shapeList[0]
1138 rank = len(ifm_shape)
1139
1140 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +01001141 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -07001142 size = []
1143
Kevin Cheng550ccc52021-03-03 11:21:43 -08001144 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -07001145
1146 for i in range(rank):
1147 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +01001148 start.append(testGen.randInt(0, ifm_shape[i]))
1149 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001150
1151 # Invalid slice size?
1152 if size[i] == 0:
1153 valid = False
1154 else:
Matthew Haddone807aae2021-10-11 18:12:58 +01001155 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -07001156 size.append(1)
1157
1158 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +01001159 # If ERROR_IF test required then incorrect start, size will be returned
1160 start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
1161 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001162 return arg_list
1163
1164 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001165 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001166 arg_list = []
1167
1168 ifm_shape = shapeList[0]
1169 rank = len(ifm_shape)
1170
1171 for p in range(testGen.args.num_rand_permutations):
1172
1173 # Pick a few random, but small multiple values
1174 # because otherwise this has a tendency to generate
1175 # enormous tensors
1176 multiples = []
1177 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001178 if ifm_shape[i] > 1000:
1179 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
1180 multiples.append(1)
1181 elif max(ifm_shape) > 1000:
1182 multiples.append(2)
1183 else:
1184 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001185 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001186
1187 return arg_list
1188
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument combinations for RESIZE tests.

        Each entry is (name, [mode, stride, offset, shift, stride_fp,
        offset_fp, output_dims, dtype, outputDType]).  For each resize mode
        and legal output dtype, random output dimensions are chosen and the
        stride/offset derived from them.  For integer dtypes a fixed-point
        shift is searched for (1..11) that keeps all derived values inside
        the spec's legal ranges; combinations with no workable shift are
        skipped.  ERROR_IF tests have their values rewritten to invalid
        ones by TosaErrorIfArgGen.eiResizeErrorIf.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # grow the output until the downscale ratio is < 16 (spec limit)
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # centre-aligned mapping between input and output grids
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # Float resize uses the fp stride/offset directly; the
                        # integer fields are zeroed and shift is unused
                        float_op = True
                        arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                    else:
                        # Integer resize: quantize stride/offset with a
                        # fixed-point shift
                        float_op = False
                        arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
                        shift = testGen.randInt(1,12)
                        # Now search for a shift value (1 to 11) that will produce
                        # a valid and predictable resize operation
                        count = 0
                        while (count < 12):
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                            # reject values outside the spec's legal ranges
                            if (
                                stride_y <= 0
                                or stride_x <= 0
                                or stride_y >= (16 << shift)
                                or stride_x >= (16 << shift)
                                or offset_y >= (16 << shift)
                                or offset_x >= (16 << shift)
                                or offset_y <= (-16 << shift)
                                or offset_x <= (-16 << shift)
                            ):
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue

                            def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift):
                                # Perform the pseudo loop to look for out of bounds
                                for pos in range(0,length_out):
                                    a = pos * stride + offset
                                    ia = a >> shift
                                    ia0 = max(ia, 0)
                                    ia1 = min(ia+1, length_in-1)
                                    if ia0 > ia1:
                                        # Found a problem value
                                        break
                                return ia0, ia1

                            iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift)
                            ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift)
                            if ix0 > ix1 or iy0 > iy1:
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue
                            break

                        if count >= 12:
                            # Couldn't find a good set of values for this test, skip it
                            continue

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        # fp fields unused for integer resize
                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                    # Common for all data types
                    if error_name is not None:
                        shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            shift,
                            stride,
                            stride_fp,
                            offset,
                            offset_fp
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_list.append(
                        (
                            arg_str.format(
                                "N" if mode == ResizeMode.NEAREST else "B",
                                shift,
                                output_dims[0],
                                output_dims[1],
                                testGen.typeStr(outputDTypeNew),
                                stride_fp[0] if float_op else stride[0],
                                stride_fp[1] if float_op else stride[1],
                                offset_fp[0] if float_op else offset[0],
                                offset_fp[1] if float_op else offset[1]
                            ),
                            [
                                mode,
                                stride,
                                offset,
                                shift,
                                stride_fp,
                                offset_fp,
                                output_dims,
                                dtype,
                                outputDTypeNew,
                            ],
                        )
                    )

        return arg_list
1351
Kevin Chengfe392ce2021-10-18 21:51:55 +00001352 @staticmethod
1353 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1354 arg_list = []
1355
1356 if dtype == DType.INT8:
1357 table = np.int32(
1358 testGen.rng.integers(low=-128, high=128, size=[256])
1359 ).tolist()
1360 else: # INT16
1361 table = np.int32(
1362 testGen.rng.integers(low=-32768, high=32768, size=[513])
1363 ).tolist()
1364
1365 arg_list.append(
1366 (
1367 "",
1368 [table],
1369 )
1370 )
1371 return arg_list
1372
Matthew Haddon1c00b712021-10-01 15:51:03 +01001373 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001374 # CondIf generates the condition values here.
1375 # Convert to tensors in the build function, along with the
1376 # then and else blocks
1377 arg_list = []
1378
1379 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001380 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001381
1382 return arg_list
1383
Matthew Haddon1c00b712021-10-01 15:51:03 +01001384 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001385 # While loop: 0 iterations, 1, more than 1
1386 arg_list = []
1387
1388 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001389 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001390
1391 return arg_list
1392
class TosaErrorIfArgGen:
    """Argument mutators for ERROR_IF testing.

    Each helper takes otherwise-valid operator arguments and corrupts them
    just enough to trigger one specific ERROR_IF condition in the reference
    model, while avoiding tripping unrelated ERROR_IF checks.
    """

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Adjust RESIZE arguments to provoke the requested error.

        Returns (shift, stride, stride_fp, offset, offset_fp, outputDType)
        with only the fields relevant to error_name modified.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                # Keep stride/offset just inside their legal range so no
                # other ERROR_IF check fires at the same time.
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1]
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1]
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1]
                    offset = [(16 << shift) - 1, (16 << shift) - 1]
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output dtype that is illegal for this mode/input dtype.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one of them made invalid for
        the requested pooling ERROR_IF, or (None, None, None) when
        error_name is not handled here."""
        if (error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if output_dtype is illegal for a RESCALE from input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        if input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Randomly grow or shrink the input/output tensor name lists for the
        WrongInputList/WrongOutputList ERROR_IF checks."""
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to max_dim and max_items."""
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # Bug fix (consistency): this was the only helper missing @staticmethod.
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) for SLICE, corrupted to trigger error_name."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i]-1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output dtypes for input_dtype."""
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # Bug fix: was `assert True, ...`, which can never fire and left
            # outputDType unbound (UnboundLocalError) — fail loudly instead.
            raise ValueError(f"input_dtype ({input_dtype}) not supported")
        return outputDType
1562
1563
Matthew Haddone86fd342021-09-07 16:12:21 +01001564class TosaErrorValidator:
1565
Matthew Haddon848efb42021-09-09 12:30:53 +01001566 @staticmethod
1567 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1568 # Check ERROR_IF statements
Matthew Haddon848efb42021-09-09 12:30:53 +01001569 for val_fcn in validator_fcns:
1570 val_result = val_fcn(True, **kwargs)
Matthew Haddon848efb42021-09-09 12:30:53 +01001571 validator_name = val_result['error_name']
1572 error_result = val_result['error_result']
1573 error_reason = val_result['error_reason']
1574
Les Bell0e027d42021-11-09 14:42:14 +00001575 # expect an error IFF the error_name and validator_name match
1576 expected_result = error_result == (error_name == validator_name)
1577
1578 if expected_result and error_result:
1579 serializer.setExpectedReturnCode(2, error_reason)
1580 elif error_result: # and not expected_result
1581 print(f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
1582 f" Expected: {error_name}, Got: {validator_name}")
1583 elif not expected_result: # and not error_result
1584 print(f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
1585 f" Expected: {error_name}")
1586
1587 if not expected_result:
1588 for k, v in sorted(kwargs.items()):
1589 if k != 'op':
1590 if k.endswith('dtype'):
1591 v = valueToName(DType, v)
1592 print(f' {k} = {v}')
Matthew Haddon848efb42021-09-09 12:30:53 +01001593
1594 @staticmethod
1595 def evWrongInputType(check=False, **kwargs):
Les Bell0e027d42021-11-09 14:42:14 +00001596 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001597
1598 # Find the unsupported input data types
Matthew Haddon848efb42021-09-09 12:30:53 +01001599 op = kwargs['op']
1600 input_dtypes = op['types']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001601 allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
Les Bell0e027d42021-11-09 14:42:14 +00001602 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
Matthew Haddon848efb42021-09-09 12:30:53 +01001603
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001604 if op['op'] == Op.CLAMP:
1605 wrong_input_dtypes.remove(DType.INT48)
1606
Matthew Haddon848efb42021-09-09 12:30:53 +01001607 if check:
1608 input_dtype = kwargs['input_dtype']
Les Bell0e027d42021-11-09 14:42:14 +00001609 if input_dtype not in allowed_input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001610 error_result = True
1611
1612 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001613 "error_name": ErrorIf.WrongInputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001614 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00001615 "error_reason": f"Input data type not supported for this operator",
1616 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
Matthew Haddon848efb42021-09-09 12:30:53 +01001617 }
1618 return info_dict
1619
1620 @staticmethod
1621 def evWrongOutputType(check=False, **kwargs):
Matthew Haddon848efb42021-09-09 12:30:53 +01001622 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001623
1624 if check:
1625 input_dtype = kwargs['input_dtype']
1626 output_dtype = kwargs['output_dtype']
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001627 op = kwargs['op']
Matthew Haddon848efb42021-09-09 12:30:53 +01001628
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001629 if op['op'] == Op.RESIZE:
1630 mode = kwargs['mode']
1631 if (
1632 (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
1633 (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
1634 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
1635 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
1636 (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
1637 ):
1638 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001639
Matthew Haddonc2025212021-10-08 21:21:05 +01001640 elif op['op'] == Op.RESCALE:
1641 if input_dtype == DType.INT8:
1642 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
1643 error_result = True
1644 if input_dtype in [DType.INT16, DType.INT32]:
1645 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1646 error_result = True
1647 elif input_dtype == DType.INT48:
1648 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1649 error_result = True
1650 elif input_dtype == DType.UINT8:
1651 if output_dtype != DType.INT8:
1652 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001653
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001654 elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
1655 if (
1656 (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
1657 (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
1658 (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
1659 ):
1660 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001661
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001662 elif op['op'] == Op.ARGMAX:
1663 if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
1664 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001665
1666 elif op['op'] == Op.MUL:
1667 if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
1668 error_result = True
1669 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1670 error_result = True
1671
1672 elif op['op'] == Op.TABLE:
1673 if input_dtype == DType.INT8 and output_dtype != DType.INT8:
1674 error_result = True
1675 elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
1676 error_result = True
1677
1678 elif op['op'] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
1679 if output_dtype != DType.BOOL:
1680 error_result = True
1681
1682 elif op['op'] == Op.CAST:
1683 if (
1684 (input_dtype == DType.BOOL and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
1685 or (input_dtype == DType.INT8 and output_dtype not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT])
1686 or (input_dtype == DType.INT16 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT])
1687 or (input_dtype == DType.INT32 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT])
1688 or (input_dtype == DType.FLOAT and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
1689 ):
1690 error_result = True
1691
Les Bell0e027d42021-11-09 14:42:14 +00001692 elif op['op'] in {Op.CONV2D, Op.CONV3D, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D}:
1693 if (
1694 input_dtype == DType.INT8 and output_dtype != DType.INT32
1695 or input_dtype == DType.INT16 and output_dtype != DType.INT48
1696 or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT
1697 ):
1698 error_result = True
1699 # invalid input types are ignored, to avoid reporting multiple errors
1700
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001701 else:
1702 if output_dtype != input_dtype:
1703 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001704
1705 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001706 "error_name": ErrorIf.WrongOutputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001707 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00001708 "error_reason": "Output data type not supported for this configuration of operator",
1709 "param_reqs": {"rank": None, "dtype": None, "shape": None}
Matthew Haddon848efb42021-09-09 12:30:53 +01001710 }
1711 return info_dict
1712
1713 @staticmethod
1714 def evWrongRank(check=False, **kwargs):
1715 all_ranks = (1, 2, 3, 4, 5)
1716
1717 # Make a list of incorrect ranks
1718 assert 'op' in kwargs
1719 op = kwargs['op']
1720 rmin, rmax = op['rank']
1721 rank_range = range(rmin, rmax + 1)
1722 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001723 # Remove small incorrect ranks to avoid index errors
1724 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
Matthew Haddon848efb42021-09-09 12:30:53 +01001725 # Set minimum incorrect rank to 3 to avoid index error
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001726 if op['op'] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001727 incorrect_ranks = [3, 5]
Les Bell0e027d42021-11-09 14:42:14 +00001728 elif op['op'] in [Op.TRANSPOSE]:
Matthew Haddon01c359d2021-10-15 16:30:48 +01001729 incorrect_ranks = [7, 8]
Les Bell0e027d42021-11-09 14:42:14 +00001730 elif op['op'] in [Op.CONV3D]:
1731 incorrect_ranks = [6, 7]
Matthew Haddon848efb42021-09-09 12:30:53 +01001732
1733 error_name = ErrorIf.WrongRank
1734 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1735 error_result = False
1736 error_reason = "Rank not supported for this operator"
1737
1738 if check:
1739 input_shape = kwargs['input_shape']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001740
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001741 if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
Matthew Haddon848efb42021-09-09 12:30:53 +01001742 error_result = True
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001743 elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
1744 error_result = True
1745 elif op['op'] == Op.MATMUL and len(input_shape) != 3:
1746 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001747 else:
1748 if len(input_shape) not in rank_range:
1749 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001750
1751 info_dict = {
1752 "error_name": error_name,
1753 "error_result": error_result,
1754 "error_reason": error_reason,
1755 "param_reqs": param_reqs
1756 }
1757 return info_dict
1758
1759 @staticmethod
1760 def evWrongInputList(check=False, **kwargs):
1761 error_name = ErrorIf.WrongInputList
1762 param_reqs = {"rank": None, "dtype": None, "shape": None}
1763 error_result = False
1764 error_reason = "Op input list does not match expected input"
1765
1766 if check:
1767 op = kwargs['op']
1768 input_list = kwargs['input_list']
1769 num_operands = kwargs['num_operands']
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001770 if op['op'] in [Op.SCATTER, Op.GATHER]:
1771 # SCATTER/GATHER add an indices input tensor in their build functions
1772 num_operands += 1
Kevin Chengfe392ce2021-10-18 21:51:55 +00001773 if len(input_list) != num_operands:
1774 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001775
1776 info_dict = {
1777 "error_name": error_name,
1778 "error_result": error_result,
1779 "error_reason": error_reason,
1780 "param_reqs": param_reqs
1781 }
1782 return info_dict
1783
1784 @staticmethod
1785 def evWrongOutputList(check=False, **kwargs):
1786 error_name = ErrorIf.WrongOutputList
1787 param_reqs = {"rank": None, "dtype": None, "shape": None}
1788 error_result = False
1789 error_reason = "Op output list does not match expected output"
1790
1791 if check:
1792 output_list = kwargs['output_list']
1793 # Note this will be incorrect if an operator returns more than one output
1794 if len(output_list) != 1:
1795 error_result = True
1796
1797 info_dict = {
1798 "error_name": error_name,
1799 "error_result": error_result,
1800 "error_reason": error_reason,
1801 "param_reqs": param_reqs
1802 }
1803 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001804
1805 @staticmethod
1806 def evMaxDimExceeded(check=False, **kwargs):
1807 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001808 param_reqs = {
1809 "rank": [4,4],
1810 "dtype": [DType.INT8],
1811 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1812 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001813 error_result = False
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001814 error_reason = "At least one maximum dimension is greater than or equal to 16384"
Matthew Haddone86fd342021-09-07 16:12:21 +01001815
1816 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001817 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001818 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001819 if ((input_shape[1] >= 16384) or
1820 (input_shape[2] >= 16384) or
1821 (output_shape[0] >= 16384) or
1822 (output_shape[1] >= 16384)):
Matthew Haddone86fd342021-09-07 16:12:21 +01001823 error_result = True
1824
1825 info_dict = {
1826 "error_name": error_name,
1827 "error_result": error_result,
1828 "error_reason": error_reason,
1829 "param_reqs": param_reqs
1830 }
1831 return info_dict
1832
1833 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001834 def evBatchMismatch(check=False, **kwargs):
1835 error_name = ErrorIf.BatchMismatch
1836 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1837 error_result = False
1838 error_reason = "Input batch size not equal to output batch size"
1839
1840 assert 'op' in kwargs
1841 op = kwargs['op']
1842 rmin, rmax = op['rank']
1843 rank_range = range(rmin, rmax + 1)
1844
1845 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001846 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001847 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1848
1849 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1850 error_result = True
1851
1852 info_dict = {
1853 "error_name": error_name,
1854 "error_result": error_result,
1855 "error_reason": error_reason,
1856 "param_reqs": param_reqs
1857 }
1858 return info_dict
1859
1860 @staticmethod
1861 def evChannelMismatch(check=False, **kwargs):
1862 error_name = ErrorIf.ChannelMismatch
1863 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1864 error_result = False
1865 error_reason = "Input channel size not equal to output channel size"
1866
1867 assert 'op' in kwargs
1868 op = kwargs['op']
1869 rmin, rmax = op['rank']
1870 rank_range = range(rmin, rmax + 1)
1871
1872 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001873 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001874 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1875 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1876 error_result = True
1877
1878 info_dict = {
1879 "error_name": error_name,
1880 "error_result": error_result,
1881 "error_reason": error_reason,
1882 "param_reqs": param_reqs
1883 }
1884 return info_dict
1885
1886 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001887 def evStrideSmallerEqualZero(check=False, **kwargs):
1888 error_name = ErrorIf.StrideSmallerEqualZero
1889 param_reqs = {"rank": None, "dtype": None, "shape": None}
1890 error_result = False
1891 error_reason = "Stride value smaller than or equal zero"
1892
1893 if check:
1894 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001895 output_dtype = kwargs['output_dtype']
1896 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1897 stride = kwargs['stride'] # Work around wrong input/output type tests
1898 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001899 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001900 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1901 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001902 else:
1903 stride = kwargs['stride']
1904
1905 if min(stride) <= 0:
1906 error_result = True
1907
1908 info_dict = {
1909 "error_name": error_name,
1910 "error_result": error_result,
1911 "error_reason": error_reason,
1912 "param_reqs": param_reqs
1913 }
1914 return info_dict
1915
1916 @staticmethod
1917 def evStrideLargerEqualMax(check=False, **kwargs):
1918 error_name = ErrorIf.StrideLargerEqualMax
1919 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1920 error_result = False
1921 error_reason = "Stride value larger than or equal to maximum value"
1922
1923 if check:
1924 shift = kwargs['shift']
1925 input_dtype = kwargs['input_dtype']
1926 stride = kwargs['stride']
1927 if input_dtype in [DType.INT8, DType.INT16]:
1928 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1929 error_result = True
1930 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1931 error_result = True
1932
1933 info_dict = {
1934 "error_name": error_name,
1935 "error_result": error_result,
1936 "error_reason": error_reason,
1937 "param_reqs": param_reqs
1938 }
1939 return info_dict
1940
1941
1942 @staticmethod
1943 def evStrideLargerDimension(check=False, **kwargs):
1944 error_name = ErrorIf.StrideLargerDimension
1945 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1946 error_result = False
1947 error_reason = "Stride value larger than or equal to H/W dimension"
1948
1949 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001950 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001951 input_dtype = kwargs['input_dtype']
1952 stride = kwargs['stride_fp']
1953
1954 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1955 error_result = True
1956
1957 info_dict = {
1958 "error_name": error_name,
1959 "error_result": error_result,
1960 "error_reason": error_reason,
1961 "param_reqs": param_reqs
1962 }
1963 return info_dict
1964
1965
1966 @staticmethod
1967 def evOffsetSmallerEqualMin(check=False, **kwargs):
1968 error_name = ErrorIf.OffsetSmallerEqualMin
1969 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1970 error_result = False
1971 error_reason = "Offset value smaller than or equal to minimum value"
1972
1973 if check:
1974 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001975 output_dtype = kwargs['output_dtype']
1976 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001977 offset = kwargs['offset_fp']
1978 else:
1979 offset = kwargs['offset']
1980
1981 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1982 error_result = True
1983 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1984 error_result = True
1985
1986 info_dict = {
1987 "error_name": error_name,
1988 "error_result": error_result,
1989 "error_reason": error_reason,
1990 "param_reqs": param_reqs
1991 }
1992 return info_dict
1993
1994 @staticmethod
1995 def evOffsetLargerEqualMax(check=False, **kwargs):
1996 error_name = ErrorIf.OffsetLargerEqualMax
1997 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1998 error_result = False
1999 error_reason = "Offset value larger than or equal to maximum value"
2000
2001 if check:
2002 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01002003 output_dtype = kwargs['output_dtype']
2004 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01002005 offset = kwargs['offset_fp']
2006 else:
2007 offset = kwargs['offset']
2008
2009 if shift >= 0:
2010 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
2011 error_result = True
2012
2013 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
2014 error_result = True
2015 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
2016 error_result = True
2017
2018 info_dict = {
2019 "error_name": error_name,
2020 "error_result": error_result,
2021 "error_reason": error_reason,
2022 "param_reqs": param_reqs
2023 }
2024 return info_dict
2025
2026 @staticmethod
2027 def evShiftNotZero(check=False, **kwargs):
2028 error_name = ErrorIf.ShiftNotZero
2029 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
2030 error_result = False
2031 error_reason = "Shift value must be zero for float input"
2032
2033 if check:
2034 shift = kwargs['shift']
2035 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01002036 output_dtype = kwargs['output_dtype']
2037 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01002038 error_result = True
2039
2040 info_dict = {
2041 "error_name": error_name,
2042 "error_result": error_result,
2043 "error_reason": error_reason,
2044 "param_reqs": param_reqs
2045 }
2046 return info_dict
2047
2048
2049 @staticmethod
2050 def evShiftSmallerOne(check=False, **kwargs):
2051 error_name = ErrorIf.ShiftSmallerOne
2052 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2053 error_result = False
2054 error_reason = "Shift value smaller than one"
2055
2056 if check:
2057 shift = kwargs['shift']
2058 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01002059 output_dtype = kwargs['output_dtype']
2060 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01002061 error_result = True
2062
2063 info_dict = {
2064 "error_name": error_name,
2065 "error_result": error_result,
2066 "error_reason": error_reason,
2067 "param_reqs": param_reqs
2068 }
2069 return info_dict
2070
2071 @staticmethod
2072 def evShiftLargerEleven(check=False, **kwargs):
2073 error_name = ErrorIf.ShiftLargerEleven
2074 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2075 error_result = False
2076 error_reason = "Shift value larger than eleven"
2077
2078 if check:
2079 shift = kwargs['shift']
2080 if shift > 11:
2081 error_result = True
2082
2083 info_dict = {
2084 "error_name": error_name,
2085 "error_result": error_result,
2086 "error_reason": error_reason,
2087 "param_reqs": param_reqs
2088 }
2089 return info_dict
2090
2091
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002092 @staticmethod
2093 def evRankMismatch(check=False, **kwargs):
2094 error_name = ErrorIf.RankMismatch
2095 param_reqs = {"rank": None, "dtype": None, "shape": None}
2096 error_result = False
2097 error_reason = "Input Rank does not match output rank"
2098
2099 if check:
2100 input1_shape = kwargs['input1'].shape
2101 input2_shape = kwargs['input2'].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002102 # In case of SELECT op
2103 input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002104 output_shape = kwargs['result_tensor'].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002105 if (
2106 (len(input1_shape) != len(output_shape)) or
2107 (len(input2_shape) != len(output_shape)) or
2108 (len(input3_shape) != len(output_shape))
2109 ):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002110 error_result = True
2111
2112 info_dict = {
2113 "error_name": error_name,
2114 "error_result": error_result,
2115 "error_reason": error_reason,
2116 "param_reqs": param_reqs
2117 }
2118 return info_dict
2119
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002120 @staticmethod
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002121 def evDimensionMismatch(check=False, **kwargs):
2122 error_name = ErrorIf.DimensionMismatch
2123 param_reqs = {"rank": None, "dtype": None, "shape": None}
2124 error_result = False
2125 error_reason = "Input Dimensions do not match output"
2126
2127 if check:
2128 input1_shape = kwargs['input1'].shape
2129 input2_shape = kwargs['input2'].shape
2130 # In case of SELECT op
2131 input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
2132 output_shape = kwargs['result_tensor'].shape
2133 for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))):
2134 if (
2135 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or
2136 (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or
2137 (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
2138 ):
2139 error_result = True
2140
2141 info_dict = {
2142 "error_name": error_name,
2143 "error_result": error_result,
2144 "error_reason": error_reason,
2145 "param_reqs": param_reqs
2146 }
2147 return info_dict
2148
2149 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002150 def evInputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002151 op = kwargs['op']
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002152 error_result = False
Les Bell0e027d42021-11-09 14:42:14 +00002153
2154 # Quantizable types
2155 qTypes = (DType.INT8, DType.UINT8)
2156
2157 # This does not apply to quantizable types
2158 inputDtypes = [
2159 dtype for dtype in op['types']
2160 if (isinstance(dtype, list) and dtype[0] not in qTypes) or
2161 (not isinstance(dtype, list) and dtype not in qTypes)
2162 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002163
2164 if check:
2165 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01002166 if isinstance(kwargs['qinfo'], tuple):
2167 qinfo = kwargs['qinfo']
2168 input_zero_point = qinfo[0]
2169 else:
2170 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
2171 qinfo = kwargs['qinfo'].ints
2172 input_zero_point = qinfo[0][1]
2173
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002174 if op['op'] == Op.MATMUL:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002175 qinfo = kwargs['qinfo'].ints
Les Bell0e027d42021-11-09 14:42:14 +00002176 for dtype, zp in (
2177 (kwargs['input_dtype'], qinfo[0][1]),
2178 (kwargs['input2_dtype'], qinfo[1][1]),
2179 ):
2180 if dtype not in qTypes and zp != 0:
2181 error_result = True
2182 break
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002183 else:
Les Bell0e027d42021-11-09 14:42:14 +00002184 error_result = input_dtype not in qTypes and input_zero_point != 0
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002185
2186 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00002187 "error_name": ErrorIf.InputZeroPointNotZero,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002188 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00002189 "error_reason": "Input DType not INT8 and zero point not 0",
2190 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None}
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002191 }
2192 return info_dict
2193
2194
2195 @staticmethod
2196 def evWeightZeroPointNotZero(check=False, **kwargs):
2197 op = kwargs['op']
2198
2199 # exclude inputs with INT8 weights
2200 inputDtypes = [t for t in op['types']
2201 if not isinstance(t, list) or t[1] != DType.INT8]
2202
2203 error_name = ErrorIf.WeightZeroPointNotZero
2204 param_reqs = {
2205 "rank": None,
2206 "dtype": inputDtypes,
2207 "shape": None
2208 }
2209 error_result = False
2210 error_reason = "Weight DType not INT8 and zero point not 0"
2211
2212 if check:
2213 weight_dtype = kwargs['weight_dtype']
2214 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
2215 qinfo = kwargs['qinfo'].ints
2216 weight_zero_point = qinfo[1][1]
2217 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002218 error_result = True
2219
2220 info_dict = {
2221 "error_name": error_name,
2222 "error_result": error_result,
2223 "error_reason": error_reason,
2224 "param_reqs": param_reqs
2225 }
2226 return info_dict
2227
2228
2229 @staticmethod
2230 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002231 op = kwargs['op']
2232 inputDtypes = op['types'].copy()
2233 if DType.INT8 in inputDtypes:
2234 inputDtypes.remove(DType.INT8)
2235 if DType.UINT8 in inputDtypes:
2236 inputDtypes.remove(DType.UINT8)
2237
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002238 error_name = ErrorIf.OutputZeroPointNotZero
2239 param_reqs = {
2240 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002241 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002242 "shape": None
2243 }
2244 error_result = False
2245 error_reason = "Output DType not INT8 and zero point not 0"
2246
2247 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002248 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01002249 output_dtype = kwargs['output_dtype']
2250 if isinstance(kwargs['qinfo'], tuple):
2251 qinfo = kwargs['qinfo']
2252 output_zero_point = qinfo[1]
2253 else:
2254 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
2255 qinfo = kwargs['qinfo'].ints
2256 output_zero_point = qinfo[1][1]
2257 if op['op'] == Op.AVG_POOL2D:
2258 if input_dtype != DType.INT8 and output_zero_point != 0:
2259 error_result = True
2260 elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002261 error_result = True
2262
2263 info_dict = {
2264 "error_name": error_name,
2265 "error_result": error_result,
2266 "error_reason": error_reason,
2267 "param_reqs": param_reqs
2268 }
2269 return info_dict
2270
Matthew Haddond6ce7252021-09-29 15:35:44 +01002271 @staticmethod
2272 def evAxisSmallerZero(check=False, **kwargs):
2273 error_name = ErrorIf.AxisSmallerZero
2274 param_reqs = {"rank": None, "dtype": None, "shape": None}
2275 error_result = False
2276 error_reason = "Axis smaller than zero"
2277
2278 if check:
2279 axis = kwargs['axis']
2280 if axis < 0:
2281 error_result = True
2282
2283 info_dict = {
2284 "error_name": error_name,
2285 "error_result": error_result,
2286 "error_reason": error_reason,
2287 "param_reqs": param_reqs
2288 }
2289 return info_dict
2290
2291
2292 @staticmethod
2293 def evAxisLargerRank(check=False, **kwargs):
2294 error_name = ErrorIf.AxisLargerRank
2295 param_reqs = {"rank": None, "dtype": None, "shape": None}
2296 error_result = False
2297 error_reason = "Axis larger than rank"
2298
2299 if check:
2300 axis = kwargs['axis']
2301 shape = kwargs['input_shape']
2302 if axis > len(shape):
2303 error_result = True
2304
2305 info_dict = {
2306 "error_name": error_name,
2307 "error_result": error_result,
2308 "error_reason": error_reason,
2309 "param_reqs": param_reqs
2310 }
2311 return info_dict
2312
2313
2314 @staticmethod
2315 def evShapeOfAxisNotOne(check=False, **kwargs):
2316 error_name = ErrorIf.ShapeOfAxisNotOne
2317 param_reqs = {"rank": None, "dtype": None, "shape": None}
2318 error_result = False
2319 error_reason = "shape[axis] is not equal to 1"
2320
2321 if check:
2322 axis = kwargs['axis']
2323 shape = kwargs['output_shape']
2324 if (0 <= axis < len(shape)) and shape[axis] != 1:
2325 error_result = True
2326
2327 info_dict = {
2328 "error_name": error_name,
2329 "error_result": error_result,
2330 "error_reason": error_reason,
2331 "param_reqs": param_reqs
2332 }
2333 return info_dict
2334
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002335
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002336 @staticmethod
2337 def evPadSmallerZero(check=False, **kwargs):
2338 error_name = ErrorIf.PadSmallerZero
2339 param_reqs = {"rank": None, "dtype": None, "shape": None}
2340 error_result = False
2341 error_reason = "At least one pad is smaller than zero"
2342
2343 if check:
Matthew Haddone807aae2021-10-11 18:12:58 +01002344 op = kwargs['op']
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002345 pad = kwargs['pad']
Matthew Haddone807aae2021-10-11 18:12:58 +01002346 if op['op'] == Op.PAD:
2347 for padding in pad:
2348 if min(padding) < 0:
2349 error_result = True
2350 else:
2351 if min(pad) < 0:
2352 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002353
2354 info_dict = {
2355 "error_name": error_name,
2356 "error_result": error_result,
2357 "error_reason": error_reason,
2358 "param_reqs": param_reqs
2359 }
2360 return info_dict
2361
2362
2363 @staticmethod
2364 def evPadLargerEqualKernel(check=False, **kwargs):
2365 error_name = ErrorIf.PadLargerEqualKernel
2366 param_reqs = {"rank": None, "dtype": None, "shape": None}
2367 error_result = False
2368 error_reason = "At least one pad is larger than kernel dimension"
2369
2370 if check:
2371 pad = kwargs['pad']
2372 kernel = kwargs['kernel']
2373 if min(pad) > 0 and min(kernel) > 1:
2374 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
2375 error_result = True
2376
2377 info_dict = {
2378 "error_name": error_name,
2379 "error_result": error_result,
2380 "error_reason": error_reason,
2381 "param_reqs": param_reqs
2382 }
2383 return info_dict
2384
2385 @staticmethod
2386 def evPoolingOutputShapeMismatch(check=False, **kwargs):
2387 error_name = ErrorIf.PoolingOutputShapeMismatch
2388 param_reqs = {"rank": None, "dtype": None, "shape": None}
2389 error_result = False
2390 error_reason = "Mismatch between output shape provided and expected output shape"
2391
2392 if check:
2393 pad = kwargs['pad']
2394 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
2395
2396 kernel = kwargs['kernel']
2397 kernel_y, kernel_x = kernel[0], kernel[1]
2398
2399 input_shape = kwargs['input_shape']
2400 IH, IW = input_shape[1], input_shape[2]
2401
2402 output_shape = kwargs['output_shape']
2403 OH, OW = output_shape[1], output_shape[2]
2404
2405 stride = kwargs['stride']
2406 stride_y, stride_x = stride[0], stride[1]
2407
2408 # calculate correct height, width dimensions
2409 if stride_x != 0 and stride_y != 0:
2410 y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
2411 x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x
2412
2413 # ensure parameters are valid
2414 params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
2415 and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))
2416
2417 if params_valid and (OH != y_correct or OW != x_correct):
2418 error_result = True
2419
2420 info_dict = {
2421 "error_name": error_name,
2422 "error_result": error_result,
2423 "error_reason": error_reason,
2424 "param_reqs": param_reqs
2425 }
2426 return info_dict
2427
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002428 @staticmethod
2429 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
2430 error_name = ErrorIf.ArgmaxOutputShapeMismatch
2431 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2432 error_result = False
2433 error_reason = "Mismatch between output shape provided and expected output shape"
2434
2435 if check:
2436 output_shape = kwargs['output_shape']
2437 input_shape = kwargs['input_shape']
2438 axis = kwargs['axis']
2439
2440 dimension_match = True
2441 axis_shift = 0
2442
2443 # Check that rank is correct before trying to check dimensions
2444 if (len(input_shape) - 1) == len(output_shape):
2445 for i in range(len(input_shape)):
2446 if i == axis:
2447 axis_shift = 1
2448 continue
2449 if input_shape[i] != output_shape[i - axis_shift]:
2450 dimension_match = False
2451
2452 if not dimension_match:
2453 error_result = True
2454
2455 info_dict = {
2456 "error_name": error_name,
2457 "error_result": error_result,
2458 "error_reason": error_reason,
2459 "param_reqs": param_reqs
2460 }
2461 return info_dict
2462
2463 @staticmethod
2464 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2465 error_name = ErrorIf.ArgmaxOutputRankMismatch
2466 param_reqs = {"rank": None, "dtype": None, "shape": None}
2467 error_result = False
2468 error_reason = "Mismatch between output shape provided and expected output shape"
2469
2470 if check:
2471 output_shape = kwargs['output_shape']
2472 input_shape = kwargs['input_shape']
2473 axis = kwargs['axis']
2474 valid_params = axis >= 0 and axis < len(input_shape)
2475
2476 if valid_params and (len(input_shape) - 1) != len(output_shape):
2477 error_result = True
2478
2479 info_dict = {
2480 "error_name": error_name,
2481 "error_result": error_result,
2482 "error_reason": error_reason,
2483 "param_reqs": param_reqs
2484 }
2485 return info_dict
2486
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002487
2488 @staticmethod
2489 def evKernelSmallerOne(check=False, **kwargs):
2490 error_name = ErrorIf.KernelSmallerOne
2491 param_reqs = {"rank": None, "dtype": None, "shape": None}
2492 error_result = False
2493 error_reason = "At least one kernel dimension is smaller than zero"
2494
2495 if check:
2496 kernel = kwargs['kernel']
2497 if min(kernel) < 1:
2498 error_result = True
2499
2500 info_dict = {
2501 "error_name": error_name,
2502 "error_result": error_result,
2503 "error_reason": error_reason,
2504 "param_reqs": param_reqs
2505 }
2506 return info_dict
2507
2508 @staticmethod
2509 def evStrideSmallerOne(check=False, **kwargs):
2510 error_name = ErrorIf.StrideSmallerOne
2511 param_reqs = {"rank": None, "dtype": None, "shape": None}
2512 error_result = False
2513 error_reason = "At least one stride dimension is smaller than zero"
2514
2515 if check:
2516 stride = kwargs['stride']
2517 if min(stride) < 1:
2518 error_result = True
2519
2520 info_dict = {
2521 "error_name": error_name,
2522 "error_result": error_result,
2523 "error_reason": error_reason,
2524 "param_reqs": param_reqs
2525 }
2526 return info_dict
2527
Matthew Haddonc2025212021-10-08 21:21:05 +01002528 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00002529 def evDilationSmallerOne(check=False, **kwargs):
2530 error_result = check and min(kwargs['dilation']) < 1
2531 return {
2532 "error_name": ErrorIf.DilationSmallerOne,
2533 "error_reason": "At least one dilation is smaller than one",
2534 "param_reqs": {"rank": None, "dtype": None, "shape": None},
2535 "error_result": error_result
2536 }
2537
2538 @staticmethod
Matthew Haddonc2025212021-10-08 21:21:05 +01002539 def evScaleTrue(check=False, **kwargs):
2540 error_name = ErrorIf.ScaleTrue
2541 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2542 error_result = False
2543 error_reason = "Scale set to true but input type is INT48"
2544
2545 if check:
2546 input_dtype = kwargs['input_dtype']
2547 scale32 = kwargs['scale32']
2548 if scale32 and input_dtype == DType.INT48:
2549 error_result = True
2550
2551 info_dict = {
2552 "error_name": error_name,
2553 "error_result": error_result,
2554 "error_reason": error_reason,
2555 "param_reqs": param_reqs
2556 }
2557 return info_dict
2558
2559 @staticmethod
2560 def evScaleNotTrue(check=False, **kwargs):
2561 error_name = ErrorIf.ScaleNotTrue
2562 param_reqs = {"rank": None, "dtype": None, "shape": None}
2563 error_result = False
2564 error_reason = "Scale set to false but double round set to true"
2565
2566 if check:
2567 scale32 = kwargs['scale32']
2568 double_round = kwargs['double_round']
2569 if not scale32 and double_round:
2570 error_result = True
2571
2572 info_dict = {
2573 "error_name": error_name,
2574 "error_result": error_result,
2575 "error_reason": error_reason,
2576 "param_reqs": param_reqs
2577 }
2578 return info_dict
2579
Matthew Haddone807aae2021-10-11 18:12:58 +01002580 @staticmethod
2581 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2582 error_name = ErrorIf.TensorSizeInputOutputMismatch
2583 param_reqs = {"rank": None, "dtype": None, "shape": None}
2584 error_result = False
2585 error_reason = "Input tensor size does not match output tensor size"
2586
2587 if check:
2588 input_shape = kwargs['input_shape']
2589 output_shape = kwargs['output_shape']
2590 input_size = np.prod(input_shape)
2591 output_size = np.prod(output_shape)
2592 if input_size != output_size:
2593 error_result = True
2594
2595 info_dict = {
2596 "error_name": error_name,
2597 "error_result": error_result,
2598 "error_reason": error_reason,
2599 "param_reqs": param_reqs
2600 }
2601 return info_dict
2602
2603 @staticmethod
2604 def evStartSmallerZero(check=False, **kwargs):
2605 error_name = ErrorIf.StartSmallerZero
2606 param_reqs = {"rank": None, "dtype": None, "shape": None}
2607 error_result = False
2608 error_reason = "Starting point smaller than zero"
2609
2610 if check:
2611 input_shape = kwargs['input_shape']
2612 start = kwargs['start']
2613 rank = len(input_shape)
2614 if len(start) == rank:
2615 for index in range(rank):
2616 if start[index] < 0:
2617 error_result = True
2618
2619 info_dict = {
2620 "error_name": error_name,
2621 "error_result": error_result,
2622 "error_reason": error_reason,
2623 "param_reqs": param_reqs
2624 }
2625 return info_dict
2626
2627
2628 @staticmethod
2629 def evSizeSmallerEqualZero(check=False, **kwargs):
2630 error_name = ErrorIf.SizeSmallerEqualZero
2631 param_reqs = {"rank": None, "dtype": None, "shape": None}
2632 error_result = False
2633 error_reason = "Size smaller than or equal to zero"
2634
2635 if check:
2636 input_shape = kwargs['input_shape']
2637 size = kwargs['size']
2638 rank = len(input_shape)
2639 if len(size) == rank:
2640 for index in range(rank):
2641 if size[index] <= 0:
2642 error_result = True
2643
2644 info_dict = {
2645 "error_name": error_name,
2646 "error_result": error_result,
2647 "error_reason": error_reason,
2648 "param_reqs": param_reqs
2649 }
2650 return info_dict
2651
2652
2653 @staticmethod
2654 def evStartSizeOutsideBounds(check=False, **kwargs):
2655 error_name = ErrorIf.StartSizeOutsideBounds
2656 param_reqs = {"rank": None, "dtype": None, "shape": None}
2657 error_result = False
2658 error_reason = "starting point plus size larger than input dimension"
2659
2660 if check:
2661 input_shape = kwargs['input_shape']
2662 start = kwargs['start']
2663 size = kwargs['size']
2664 rank = len(input_shape)
2665 if len(start) == rank and len(size) == rank:
2666 for index in range(rank):
2667 if start[index] + size[index] > input_shape[index]:
2668 error_result = True
2669
2670 info_dict = {
2671 "error_name": error_name,
2672 "error_result": error_result,
2673 "error_reason": error_reason,
2674 "param_reqs": param_reqs
2675 }
2676 return info_dict
2677
2678
2679 @staticmethod
2680 def evSizeOutputShapeMismatch(check=False, **kwargs):
2681 error_name = ErrorIf.SizeOutputShapeMismatch
2682 param_reqs = {"rank": None, "dtype": None, "shape": None}
2683 error_result = False
2684 error_reason = "Size does not match output dimension"
2685
2686 if check:
2687 input_shape = kwargs['input_shape']
2688 output_shape = kwargs['output_shape']
2689 size = kwargs['size']
2690 rank = len(input_shape)
2691 if len(size) == rank:
2692 for index in range(rank):
2693 if size[index] != output_shape[index]:
2694 error_result = True
2695
2696 info_dict = {
2697 "error_name": error_name,
2698 "error_result": error_result,
2699 "error_reason": error_reason,
2700 "param_reqs": param_reqs
2701 }
2702 return info_dict
2703
2704 @staticmethod
2705 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2706 error_name = ErrorIf.InputSizeStartLengthMismatch
2707 param_reqs = {"rank": None, "dtype": None, "shape": None}
2708 error_result = False
2709 error_reason = "rank of input not equal to length of start or size"
2710
2711 if check:
2712 input_shape = kwargs['input_shape']
2713 start = kwargs['start']
2714 size = kwargs['size']
2715 rank = len(input_shape)
2716 if rank != len(start) or rank != len(size):
2717 error_result = True
2718
2719 info_dict = {
2720 "error_name": error_name,
2721 "error_result": error_result,
2722 "error_reason": error_reason,
2723 "param_reqs": param_reqs
2724 }
2725 return info_dict
2726
2727 @staticmethod
2728 def evIndexOutsideBounds(check=False, **kwargs):
2729 error_name = ErrorIf.IndexOutsideBounds
2730 param_reqs = {"rank": None, "dtype": None, "shape": None}
2731 error_result = False
2732 error_reason = "Index outside of allowed bounds"
2733
2734 if check:
2735 input_shape = kwargs['input_shape']
2736 perms = kwargs['perms']
2737 rank = len(input_shape)
2738
2739 for index in perms:
2740 if index < 0 or index > rank:
2741 error_result = True
2742
2743 info_dict = {
2744 "error_name": error_name,
2745 "error_result": error_result,
2746 "error_reason": error_reason,
2747 "param_reqs": param_reqs
2748 }
2749 return info_dict
2750
2751 @staticmethod
2752 def evIndexUsedTwice(check=False, **kwargs):
2753 error_name = ErrorIf.IndexUsedTwice
2754 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2755 error_result = False
2756 error_reason = "Index used multiple times"
2757
2758 if check:
2759 input_shape = kwargs['input_shape']
2760 perms = kwargs['perms']
2761 rank = len(input_shape)
2762
2763 unique_indices = []
2764 for index in perms:
2765 if index in unique_indices:
2766 error_result = True
2767 else:
2768 unique_indices.append(index)
2769
2770 info_dict = {
2771 "error_name": error_name,
2772 "error_result": error_result,
2773 "error_reason": error_reason,
2774 "param_reqs": param_reqs
2775 }
2776 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002777
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002778 @staticmethod
2779 def evMaxSmallerMin(check=False, **kwargs):
2780 error_name = ErrorIf.MaxSmallerMin
2781 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2782 error_result = False
2783 error_reason = "Max value smaller than min value"
2784
2785 if check:
2786 max_val = kwargs['max_val']
2787 min_val = kwargs['min_val']
2788 if max_val < min_val:
2789 error_result = True
2790
2791
2792 info_dict = {
2793 "error_name": error_name,
2794 "error_result": error_result,
2795 "error_reason": error_reason,
2796 "param_reqs": param_reqs
2797 }
2798 return info_dict
2799
2800 @staticmethod
2801 def evConcatInputRankMismatch(check=False, **kwargs):
2802 error_name = ErrorIf.ConcatInputRankMismatch
2803 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2804 error_result = False
2805 error_reason = "Input ranks are not identical"
2806
2807 if check:
2808 inputs = kwargs['inputs']
2809 input_shape = kwargs['input_shape']
2810 for input in inputs:
2811 if len(input.shape) != len(input_shape):
2812 error_result = True
2813
2814 info_dict = {
2815 "error_name": error_name,
2816 "error_result": error_result,
2817 "error_reason": error_reason,
2818 "param_reqs": param_reqs
2819 }
2820 return info_dict
2821
2822 @staticmethod
2823 def evConcatInputDimMismatch(check=False, **kwargs):
2824 error_name = ErrorIf.ConcatInputDimMismatch
2825 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2826 error_result = False
2827 error_reason = "Input dimensions differ on too many axes"
2828
2829 if check:
2830 inputs = kwargs['inputs']
2831 input_shape = kwargs['input_shape']
2832 axis = kwargs['axis']
2833
2834 # Ensure rank is valid before checking dims.
2835 valid_rank = True
2836 for input in inputs:
2837 if len(input.shape) != len(input_shape):
2838 valid_rank = False
2839
2840 if valid_rank:
2841 for input in inputs:
2842 for i, dim in enumerate(input.shape):
2843 if dim != input_shape[i] and axis != i:
2844 error_result = True
2845
2846 info_dict = {
2847 "error_name": error_name,
2848 "error_result": error_result,
2849 "error_reason": error_reason,
2850 "param_reqs": param_reqs
2851 }
2852 return info_dict
2853
Matthew Haddon630c17c2021-10-14 15:05:41 +01002854 @staticmethod
Matthew Haddon01c359d2021-10-15 16:30:48 +01002855 def evConcatShapeSumMismatch(check=False, **kwargs):
2856 error_name = ErrorIf.ConcatShapeSumMismatch
2857 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2858 error_result = False
2859 error_reason = "Sum of dimensions on axis not equal to output dimension"
2860
2861 if check:
2862 inputs = kwargs['inputs']
2863 input_shape = kwargs['input_shape']
2864 output_shape = kwargs['output_shape']
2865 axis = kwargs['axis']
2866
2867 # Ensure rank is valid before checking dims.
2868 valid_params = True
2869 for input in inputs:
2870 if len(input.shape) != len(input_shape):
2871 valid_params = False
2872 if axis < 0 or axis > len(input_shape):
2873 valid_params = False
2874
2875 if valid_params:
2876 axis_dim_sum = 0
2877 for input in inputs:
2878 axis_dim_sum += input.shape[axis]
2879
2880 if axis_dim_sum != output_shape[axis]:
2881 error_result = True
2882
2883
2884 info_dict = {
2885 "error_name": error_name,
2886 "error_result": error_result,
2887 "error_reason": error_reason,
2888 "param_reqs": param_reqs
2889 }
2890 return info_dict
2891
2892 @staticmethod
Matthew Haddon630c17c2021-10-14 15:05:41 +01002893 def evInputListThenGraphMismatch(check=False, **kwargs):
2894 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2895 param_reqs = {"rank": None, "dtype": None, "shape": None}
2896 error_result = False
2897 error_reason = "Input list shape does not match then-graph shape"
2898
2899 if check:
2900 a = kwargs['a']
2901 b = kwargs['b']
2902 basicBlocks = kwargs['basicBlocks']
2903 then_block = basicBlocks[1]
2904 then_inputs = then_block.inputs
2905 then_tens = then_block.tensors
2906 if (a.shape != then_tens[then_inputs[0]].shape) or (b.shape != then_tens[then_inputs[1]].shape):
2907 error_result = True
2908
2909 info_dict = {
2910 "error_name": error_name,
2911 "error_result": error_result,
2912 "error_reason": error_reason,
2913 "param_reqs": param_reqs
2914 }
2915 return info_dict
2916
2917
2918 @staticmethod
2919 def evInputListElseGraphMismatch(check=False, **kwargs):
2920 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2921 param_reqs = {"rank": None, "dtype": None, "shape": None}
2922 error_result = False
2923 error_reason = "Input list shape does not match else-graph shape"
2924
2925 if check:
2926 a = kwargs['a']
2927 b = kwargs['b']
2928 basicBlocks = kwargs['basicBlocks']
2929 else_block = basicBlocks[2]
2930 else_inputs = else_block.inputs
2931 else_tens = else_block.tensors
2932 if (a.shape != else_tens[else_inputs[0]].shape) or (b.shape != else_tens[else_inputs[1]].shape):
2933 error_result = True
2934
2935 info_dict = {
2936 "error_name": error_name,
2937 "error_result": error_result,
2938 "error_reason": error_reason,
2939 "param_reqs": param_reqs
2940 }
2941 return info_dict
2942
2943
2944 @staticmethod
2945 def evOutputListThenGraphMismatch(check=False, **kwargs):
2946 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2947 param_reqs = {"rank": None, "dtype": None, "shape": None}
2948 error_result = False
2949 error_reason = "Output list shape does not match then-graph shape"
2950
2951 if check:
2952 basicBlocks = kwargs['basicBlocks']
2953 cond_block = basicBlocks[0]
2954 cond_outputs = cond_block.outputs
2955 cond_tens = cond_block.tensors
2956 then_block = basicBlocks[1]
2957 then_outputs = then_block.outputs
2958 then_tens = then_block.tensors
2959 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2960 error_result = True
2961
2962 info_dict = {
2963 "error_name": error_name,
2964 "error_result": error_result,
2965 "error_reason": error_reason,
2966 "param_reqs": param_reqs
2967 }
2968 return info_dict
2969
2970
2971 @staticmethod
2972 def evOutputListElseGraphMismatch(check=False, **kwargs):
2973 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2974 param_reqs = {"rank": None, "dtype": None, "shape": None}
2975 error_result = False
2976 error_reason = "Output list shape does not match else-graph shape"
2977
2978 if check:
2979 basicBlocks = kwargs['basicBlocks']
2980 cond_block = basicBlocks[0]
2981 cond_outputs = cond_block.outputs
2982 cond_tens = cond_block.tensors
2983 else_block = basicBlocks[2]
2984 else_outputs = else_block.outputs
2985 else_tens = else_block.tensors
2986 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2987 error_result = True
2988
2989 info_dict = {
2990 "error_name": error_name,
2991 "error_result": error_result,
2992 "error_reason": error_reason,
2993 "param_reqs": param_reqs
2994 }
2995 return info_dict
2996
2997
2998 @staticmethod
2999 def evInputListOutputListMismatch(check=False, **kwargs):
3000 error_name = ErrorIf.InputListOutputListMismatch
3001 param_reqs = {"rank": None, "dtype": None, "shape": None}
3002 error_result = False
3003 error_reason = "Input list does not match output list"
3004
3005 if check:
3006 basicBlocks = kwargs['basicBlocks']
3007 while_block = basicBlocks[0]
3008 while_inputs = while_block.inputs
3009 while_outputs = while_block.outputs
3010 while_tens = while_block.tensors
3011 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
3012 error_result = True
3013
3014 info_dict = {
3015 "error_name": error_name,
3016 "error_result": error_result,
3017 "error_reason": error_reason,
3018 "param_reqs": param_reqs
3019 }
3020 return info_dict
3021
3022
3023 @staticmethod
3024 def evInputListCondGraphMismatch(check=False, **kwargs):
3025 error_name = ErrorIf.InputListCondGraphMismatch
3026 param_reqs = {"rank": None, "dtype": None, "shape": None}
3027 error_result = False
3028 error_reason = "Input list does not match cond graph"
3029
3030 if check:
3031 basicBlocks = kwargs['basicBlocks']
3032 while_block = basicBlocks[0]
3033 while_inputs = while_block.inputs
3034 while_tens = while_block.tensors
3035 cond_block = basicBlocks[1]
3036 cond_inputs = cond_block.inputs
3037 cond_tens = cond_block.tensors
3038 if ((while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape) or
3039 (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape)):
3040 error_result = True
3041
3042 info_dict = {
3043 "error_name": error_name,
3044 "error_result": error_result,
3045 "error_reason": error_reason,
3046 "param_reqs": param_reqs
3047 }
3048 return info_dict
3049
3050
3051 @staticmethod
3052 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
3053 error_name = ErrorIf.InputListBodyGraphInputMismatch
3054 param_reqs = {"rank": None, "dtype": None, "shape": None}
3055 error_result = False
3056 error_reason = "Input list does not match body graph input"
3057
3058 if check:
3059 basicBlocks = kwargs['basicBlocks']
3060 while_block = basicBlocks[0]
3061 while_inputs = while_block.inputs
3062 while_tens = while_block.tensors
3063 body_block = basicBlocks[2]
3064 body_outputs = body_block.inputs
3065 body_tens = body_block.tensors
3066 if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
3067 (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
3068 error_result = True
3069
3070 info_dict = {
3071 "error_name": error_name,
3072 "error_result": error_result,
3073 "error_reason": error_reason,
3074 "param_reqs": param_reqs
3075 }
3076 return info_dict
3077
3078
3079 @staticmethod
3080 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
3081 error_name = ErrorIf.InputListBodyGraphOutputMismatch
3082 param_reqs = {"rank": None, "dtype": None, "shape": None}
3083 error_result = False
3084 error_reason = "Input list does not match body graph output"
3085
3086 if check:
3087 basicBlocks = kwargs['basicBlocks']
3088 while_block = basicBlocks[0]
3089 while_inputs = while_block.inputs
3090 while_tens = while_block.tensors
3091 body_block = basicBlocks[2]
3092 body_outputs = body_block.outputs
3093 body_tens = body_block.tensors
3094 if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
3095 (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
3096 error_result = True
3097 info_dict = {
3098 "error_name": error_name,
3099 "error_result": error_result,
3100 "error_reason": error_reason,
3101 "param_reqs": param_reqs
3102 }
3103 return info_dict
3104
3105
3106 @staticmethod
3107 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
3108 error_name = ErrorIf.CondGraphOutputNotMatchingBool
3109 param_reqs = {"rank": None, "dtype": None, "shape": None}
3110 error_result = False
3111 error_reason = "Cond graph output is not a match list of booleans"
3112
3113 if check:
3114 basicBlocks = kwargs['basicBlocks']
3115 cond_block = basicBlocks[1]
3116 cond_outputs = cond_block.outputs
3117 cond_tens = cond_block.tensors
3118 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
3119 error_result = True
3120
3121 info_dict = {
3122 "error_name": error_name,
3123 "error_result": error_result,
3124 "error_reason": error_reason,
3125 "param_reqs": param_reqs
3126 }
3127 return info_dict
3128
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003129
class TosaInvalidValidator:
    """Predicates that report whether a generated argument set is invalid.

    Each iv* static method receives the candidate test parameters as keyword
    arguments and returns True when the combination cannot form a valid TOSA
    operator invocation (so the generator can discard it or use it as a
    negative test), False otherwise.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True for an invalid RESIZE dtype pairing or resize mode."""
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        # args layout: [mode, stride, ..., stride_fp, ..., output_dtype]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Valid BILINEAR pairings are INT8->INT32, INT16->INT48 and
            # FLOAT->FLOAT; any other combination is invalid.
            # Bug fix: the previous expression OR'ed the *negations* of the
            # three pairings - since at most one pairing can hold, at least
            # two negations were always True, so every case (including the
            # valid ones) was reported invalid.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            # NOTE(review): DType.INT32 in this allow-list looks suspicious
            # (NEAREST normally takes INT8/INT16/FLOAT) - confirm intent.
            return (
                (input_dtype != output_dtype) or
                (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if any RESIZE stride is zero or negative.

        Integer strides (args[1]) are used for integer input dtypes; the
        floating-point strides (args[4]) are used when the input is FLOAT.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True if the op's output height/width would be invalid.

        Supports avg/max_pool2d, transpose_conv2d and the conv2d/conv3d
        family; dispatches on kwargs['opName'].  kwargs['shapeList'][0] is
        the input shape, kwargs['args'] holds [strides, padding, ...].
        """
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        input_shape = inputShapes[0]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            h = (input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]) // strides[0]
            w = (input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]) // strides[1]
            # Invalid if any output dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            dilations = args[2]
            output_shape = args[3]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Based on the keras function deconv_output_length, in
                https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    dilation: the kernel dilation - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
                return (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad

            # The requested output shape is valid if it matches the size that
            # any of the standard padding conventions would produce.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0)  # VALID padding
            ):
                h = get_out_size(input_shape[1], strides[0], kernel_shape[0], dilations[0],
                                 padding[0], pad_h)
                w = get_out_size(input_shape[2], strides[1], kernel_shape[1], dilations[1],
                                 padding[1], pad_w)
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters are laid out KH,KW,C,M; others are O,KH,(KD,)KW,I
            kernel_shape = filter_shape[0:2] if opName.startswith("depthwise_conv2d") else filter_shape[1:-1]

            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # Invalid if any output dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if the requested output H or W is non-positive."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
3268
3269
Eric Kunzee5e26762020-10-13 16:11:07 -07003270class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6
3273
    def __init__(self, args):
        """Initialise the test generator from parsed command-line arguments.

        Reads args.output_dir and args.random_seed; the serializer itself is
        created lazily by createSerializer().
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
3285
    def createSerializer(self, opName, testPath):
        """Create a TosaSerializer writing under basePath/opName/testPath,
        creating the output directory if needed."""
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)
3292
    def getSerializer(self):
        """Return the current TosaSerializer (None until createSerializer)."""
        return self.ser
3295
    def serialize(self, testName):
        """Write the serialized test to <testPath>/<testName>.tosa and a
        desc.json test descriptor alongside it."""
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07003304
Matthew Haddon74567092021-07-16 15:38:20 +01003305 def resetRNG(self, seed=None):
3306 if seed == None:
3307 seed = self.random_seed + 1
3308 self.rng = np.random.default_rng(seed)
3309
Eric Kunzee5e26762020-10-13 16:11:07 -07003310 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07003311 if dtype == DType.BOOL:
3312 np_dt = np.bool
3313 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07003314 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07003315 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07003316 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003317 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003318 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
3319 elif dtype == DType.UINT8:
3320 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003321 elif dtype == DType.INT16:
3322 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
3323 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003324 return np.int32(
3325 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
3326 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003327 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003328 return np.int64(
3329 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
3330 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003331 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003332 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003333 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003334 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003335
Kevin Cheng989cb052021-04-28 16:29:44 -07003336 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003337 placeholders = []
3338
Kevin Cheng989cb052021-04-28 16:29:44 -07003339 assert len(shape_list) == len(dtype_list)
3340
3341 for idx, shape in enumerate(shape_list):
3342 arr = self.getRandTensor(shape, dtype_list[idx])
3343 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003344
3345 return placeholders
3346
Kevin Cheng989cb052021-04-28 16:29:44 -07003347 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003348 consts = []
3349
Kevin Cheng989cb052021-04-28 16:29:44 -07003350 assert len(shape_list) == len(dtype_list)
3351
3352 for idx, shape in enumerate(shape_list):
3353 arr = self.getRandTensor(shape, dtype_list[idx])
3354 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003355
3356 return consts
3357
3358 def makeShape(self, rank):
3359 if self.targetted_shape:
3360 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003361 return np.int32(
3362 self.rng.integers(
3363 low=self.args.tensor_shape_range[0],
3364 high=self.args.tensor_shape_range[1],
3365 size=rank,
3366 )
3367 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003368
    def setTargetShape(self, shape):
        """Force makeShape() to return this shape (pass None to re-enable
        random shapes)."""
        self.targetted_shape = shape
3371
3372 def randInt(self, low=0, high=256):
3373 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
3374
    def getRandNumberDType(self, dtype):
        """Return one random scalar value appropriate for the given dtype.

        FLOAT returns a float in [0, 1); BOOL a python bool; INT48 an
        np.int64; all other integer dtypes an np.int32 over their full
        TOSA range (INT4 restricted to the weight range -7..7).

        Raises:
            Exception: if dtype is not a recognized DType value.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size - INT48 needs a 64-bit container, so return here.
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]
3397
3398 def shapeStr(self, shape):
3399
3400 sStr = []
3401 # Convert to strings
3402 for i in shape:
3403 sStr.append(str(i))
3404
Kevin Cheng550ccc52021-03-03 11:21:43 -08003405 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003406
3407 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07003408 if isinstance(t, list):
3409 assert len(t) >= 2
3410 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07003411 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07003412 if t == DType.BOOL:
3413 return "b"
3414 elif t == DType.INT4:
3415 return "i4"
3416 elif t == DType.INT8:
3417 return "i8"
3418 elif t == DType.UINT8:
3419 return "u8"
3420 elif t == DType.INT16:
3421 return "i16"
3422 elif t == DType.INT32:
3423 return "i32"
3424 elif t == DType.INT48:
3425 return "i48"
3426 elif t == DType.FLOAT:
3427 return "float"
3428 else:
3429 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003430
3431 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003432 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08003433 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07003434 return 4
3435 elif t == DType.INT8:
3436 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08003437 elif t == DType.UINT8:
3438 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07003439 elif t == DType.INT16:
3440 return 16
3441 elif t == DType.INT32:
3442 return 32
3443 elif t == DType.INT48:
3444 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01003445 elif t == DType.FLOAT:
3446 return 32
3447 elif t == DType.BOOL:
3448 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003449 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003450 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003451
3452 # Argument generators
3453 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
3454 # Where the string descriptor is used to generate the test name and
3455 # The build_fcn_arg_list is expanded and passed to the operator test
3456 # build function
3457
    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        """Serialize a unary elementwise operator consuming tensor `a`.

        validator_fcns/error_name drive negative (ERROR_IF) test creation:
        the input/output lists and qinfo may be deliberately invalidated.
        Returns the output tensor.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # build_placeholder returns an int, ABS/other ops does not
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
            return result_tens
        elif op['op'] == Op.IDENTITY:
            self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Validate the (possibly invalidated) arguments before serializing.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens
3500
    def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
        """Serialize a broadcastable binary operator on tensors `a` and `b`.

        error_name may trigger deliberate corruption of the input/output
        lists to build ERROR_IF negative tests.  Returns the output tensor.
        """
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)


        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = a,
            input2 = b,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list)
        return result_tens
3529
    def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
        """Serialize a binary operator on identically-shaped `a` and `b`.

        validator_fcns/error_name are accepted for signature parity with the
        other build_* methods but are not used here.  Returns the output tensor.
        """
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
        return result_tens
3534
    def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None):
        """Serialize ARITHMETIC_RIGHT_SHIFT of `a` by `b`.

        `round` (bool) selects the op's rounding behaviour and is carried in
        the operator attribute.  Returns the output tensor.
        """
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = a,
            input2 = b,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
3565
    def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
        """Serialize MUL of `a` and `b` with the given result shift.

        Integer multiplies always produce an INT32 result (unless an
        ERROR_IF test deliberately forces a wrong output dtype).
        Returns the output tensor.
        """
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
            outputDType = self.rng.choice(all_dtypes)
            result_tens.setDtype(outputDType)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = a,
            input2 = b,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
3605
    def build_table(self, op, a, table, validator_fcns=None, error_name=None):
        """Serialize TABLE lookup of `a` using `table` (carried as an
        operator attribute).  Returns the output tensor."""
        result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.TableAttribute(table)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, attr)

        return result_tens
3636
    def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
        """Serialize SELECT: choose elementwise from `a`/`b` based on `cond`.

        Returns the output tensor.
        """
        result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [cond.name, a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = cond,
            input2 = a,
            input3 = b,
            input_shape = a.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list,)
        return result_tens
3666
    def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
        """Serialize a comparison operator on `a` and `b` (boolean output).

        Returns the output tensor.
        """
        result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1 = a,
            input2 = b,
            input_shape = a.shape,
            input_dtype = a.dtype,
            output_shape = result_tens.shape,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list,)
        return result_tens
3696
    def build_argmax(self, op, a, axis, validator_fcns, error_name):
        """Serialize ARGMAX of `a` along `axis` (carried as an attribute).

        Returns the output tensor.
        """
        result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape = a.shape,
            input_dtype = a.dtype,
            output_shape = result_tens.shape,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
3728
    def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
        """Serialize a 2D pooling operator (avg/max) over `input`.

        kernel/stride/pad are carried in the PoolAttribute; qinfo is
        regenerated when a WrongInputType ERROR_IF test changes the dtype.
        Returns the output tensor.
        """
        result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType:
            if input.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=input.shape,
            input_dtype=input.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            kernel=kernel,
            stride=stride,
            pad=pad,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad)

        self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
        return result_tens
3771
    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
        """Serialize CONV2D of `ifm` with `filter` and `bias`.

        padding must have 4 entries (top/bottom/left/right); strides and
        dilations are carried in the ConvAttribute.  Returns the output tensor.
        """
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
            )

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
3816
    def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
        """Serialize CONV3D of `ifm` with `filter` and `bias`.

        padding must have 6 entries (pairs per spatial dim); strides and
        dilations are carried in the ConvAttribute.  Returns the output tensor.
        """
        assert len(padding) == 6
        result_tens = OutputShaper.conv3dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
            )

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
3861
    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Serialize TRANSPOSE_CONV2D of `ifm` with `filter` and `bias`.

        outpad must have 2 entries; the explicit output_shape is carried in
        the TransposeConvAttribute.  Returns the output tensor.
        """
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, self.rng, ifm, output_shape, error_name)

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
            )

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=outpad,
            stride=stride,
            dilation=dilation,
            input_shape=ifm.shape,
        )

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
3906
    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Serialize DEPTHWISE_CONV2D of `ifm` with `filter` and `bias`.

        strides/padding/dilations are carried in the ConvAttribute.
        Returns the output tensor.
        """
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
            )

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
3952
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003953 def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
3954 result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)
3955
3956 # Invalidate Input/Output list for error if checks.
3957 input_list = [ifm.name, filter.name, bias.name]
3958 output_list = [result_tens.name]
3959 pCount, cCount = op["operands"]
3960 num_operands = pCount + cCount
3961 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3962
3963 TosaErrorValidator.evValidateErrorIfs(
3964 self.ser,
3965 validator_fcns,
3966 error_name,
3967 op=op,
3968 input_shape=ifm.shape,
3969 input_dtype=ifm.dtype,
3970 weight_dtype=filter.dtype,
3971 output_shape=result_tens.shape,
3972 output_dtype=result_tens.dtype,
3973 qinfo = qinfo,
3974 result_tensor = result_tens,
3975 input_list=input_list,
3976 output_list=output_list,
3977 num_operands=num_operands,
3978 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003979
Kevin Cheng550ccc52021-03-03 11:21:43 -08003980 self.ser.addOperator(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003981 op['op'], input_list, output_list, None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003982 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003983 return result_tens
3984
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003985 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
3986 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
3987
3988 # Invalidate Input/Output list for error if checks.
3989 input_list = [a.name, b.name]
3990 output_list = [result_tens.name]
3991 pCount, cCount = op["operands"]
3992 num_operands = pCount + cCount
3993 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3994
3995 TosaErrorValidator.evValidateErrorIfs(
3996 self.ser,
3997 validator_fcns,
3998 error_name,
3999 op=op,
4000 input_shape=a.shape,
4001 input_dtype=a.dtype,
4002 input2_shape=b.shape,
4003 input2_dtype=b.dtype,
4004 output_shape=result_tens.shape,
4005 output_dtype=result_tens.dtype,
4006 qinfo = qinfo,
4007 result_tensor = result_tens,
4008 input_list=input_list,
4009 output_list=output_list,
4010 num_operands=num_operands,
4011 )
4012
4013 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004014 return result_tens
4015
Matthew Haddond6ce7252021-09-29 15:35:44 +01004016 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
4017 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
4018
4019 # Invalidate Input/Output list for error if checks.
4020 input_list = [a.name]
4021 output_list = [result_tens.name]
4022 pCount, cCount = op["operands"]
4023 num_operands = pCount + cCount
4024 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4025
4026 TosaErrorValidator.evValidateErrorIfs(
4027 self.ser,
4028 validator_fcns,
4029 error_name,
4030 op=op,
4031 axis = axis,
4032 input_shape = a.shape,
4033 output_shape = result_tens.shape,
4034 input_dtype = a.dtype,
4035 output_dtype = result_tens.dtype,
4036 result_tensor = result_tens,
4037 input_list=input_list,
4038 output_list=output_list,
4039 num_operands=num_operands,
4040 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004041
4042 attr = ts.TosaSerializerAttribute()
4043 attr.AxisAttribute(axis)
4044
Matthew Haddond6ce7252021-09-29 15:35:44 +01004045 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004046 return result_tens
4047
    def build_clamp(self, op, a, validator_fcns=None, error_name=None):
        """Serialize a CLAMP operator on tensor a and return the output tensor.

        Two random values of a's dtype become the clamp min/max. For the
        MaxSmallerMin ERROR_IF case the assignment is deliberately inverted
        (max < min) to provoke the error.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        if error_name == ErrorIf.MaxSmallerMin:
            # Make sure the numbers are different to invoke this error
            while v[0] == v[1]:
                v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
            max_val = min(v)
            min_val = max(v)
        else:
            max_val = max(v)
            min_val = min(v)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            max_val=max_val,
            min_val=min_val,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        # Float clamps populate the floating-point min/max attribute fields;
        # integer clamps populate the integer fields. The unused pair is zeroed.
        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min_val, max_val)
        else:
            attr.ClampAttribute(min_val, max_val, 0, 0)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4095
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004096 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
4097 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004098 attr = ts.TosaSerializerAttribute()
4099
4100 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
4101
Matthew Haddon848efb42021-09-09 12:30:53 +01004102 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004103 return result_tens
4104
4105 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004106 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
4107 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004108
Matthew Haddon848efb42021-09-09 12:30:53 +01004109 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07004110 return result_tens
4111
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004112 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
4113 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4114
4115 # Invalidate Input/Output list for error if checks.
4116 input_list = [a.name]
4117 output_list = [result_tens.name]
4118 pCount, cCount = op["operands"]
4119 num_operands = pCount + cCount
4120 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4121
4122 TosaErrorValidator.evValidateErrorIfs(
4123 self.ser,
4124 validator_fcns,
4125 error_name,
4126 op=op,
4127 input_shape = a.shape,
4128 output_shape = result_tens.shape,
4129 input_dtype = a.dtype,
4130 output_dtype = result_tens.dtype,
4131 result_tensor = result_tens,
4132 input_list=input_list,
4133 output_list=output_list,
4134 num_operands=num_operands,
4135 )
4136
4137 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004138 return result_tens
4139
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004140 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
4141 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4142
4143 # Invalidate Input/Output list for error if checks.
4144 input_list = [a.name]
4145 output_list = [result_tens.name]
4146 pCount, cCount = op["operands"]
4147 num_operands = pCount + cCount
4148 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4149
4150 TosaErrorValidator.evValidateErrorIfs(
4151 self.ser,
4152 validator_fcns,
4153 error_name,
4154 op=op,
4155 input_shape = a.shape,
4156 output_shape = result_tens.shape,
4157 input_dtype = a.dtype,
4158 output_dtype = result_tens.dtype,
4159 result_tensor = result_tens,
4160 input_list=input_list,
4161 output_list=output_list,
4162 num_operands=num_operands,
4163 )
4164
4165 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004166 return result_tens
4167
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004168 def build_concat(self, op, *a, validator_fcns=None, error_name=None):
4169 if error_name != ErrorIf.WrongInputType:
4170 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01004171
4172 # To store variable length list of input tensors we need to store axis along with it
4173 axis = a[-1]
4174 a = a[:-1]
4175
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004176 result_tens = OutputShaper.concatOp(self.ser, self.rng, axis, *a, error_name=error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004177
Matthew Haddon818ab902021-07-27 09:12:49 +01004178 input_tensor_names = []
4179 for tensor in a:
4180 input_tensor_names.append(tensor.name)
4181
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004182 # Invalidate Input/Output list for error if checks.
4183 input_list = input_tensor_names
4184 output_list = [result_tens.name]
4185 pCount, cCount = op["operands"]
4186 num_operands = pCount + cCount
4187 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4188
4189 TosaErrorValidator.evValidateErrorIfs(
4190 self.ser,
4191 validator_fcns,
4192 error_name,
4193 op=op,
4194 axis=axis,
4195 input_shape = a[0].shape,
4196 output_shape = result_tens.shape,
4197 input_dtype = a[0].dtype,
4198 output_dtype = result_tens.dtype,
4199 inputs=a,
4200 result_tensor = result_tens,
4201 input_list=input_list,
4202 output_list=output_list,
4203 num_operands=num_operands,
4204 )
4205
4206 attr = ts.TosaSerializerAttribute()
4207 attr.AxisAttribute(axis)
4208
4209
4210 self.ser.addOperator(op['op'], input_list, output_list, attr)
Matthew Haddon848efb42021-09-09 12:30:53 +01004211 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004212
Kevin Chengfe392ce2021-10-18 21:51:55 +00004213 def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
Matthew Haddone807aae2021-10-11 18:12:58 +01004214 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004215
Kevin Chengfe392ce2021-10-18 21:51:55 +00004216 attr = ts.TosaSerializerAttribute()
4217 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07004218
Matthew Haddone807aae2021-10-11 18:12:58 +01004219 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00004220 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01004221 output_list = [result_tens.name]
4222 pCount, cCount = op["operands"]
4223 num_operands = pCount + cCount
4224 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4225
4226 TosaErrorValidator.evValidateErrorIfs(
4227 self.ser,
4228 validator_fcns,
4229 error_name,
4230 op=op,
4231 input_shape = a.shape,
4232 output_shape = result_tens.shape,
4233 input_dtype = a.dtype,
4234 output_dtype = result_tens.dtype,
4235 pad=padding,
4236 qinfo=qinfo,
4237 result_tensor = result_tens,
4238 input_list=input_list,
4239 output_list=output_list,
4240 num_operands=num_operands,
4241 )
4242
Kevin Cheng550ccc52021-03-03 11:21:43 -08004243 self.ser.addOperator(
Kevin Chengfe392ce2021-10-18 21:51:55 +00004244 op['op'], input_list, output_list, attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08004245 )
Matthew Haddone86fd342021-09-07 16:12:21 +01004246 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004247
Matthew Haddone807aae2021-10-11 18:12:58 +01004248 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
4249 result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)
4250
4251 # Invalidate Input/Output list for error if checks.
4252 input_list = [a.name]
4253 output_list = [result_tens.name]
4254 pCount, cCount = op["operands"]
4255 num_operands = pCount + cCount
4256 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4257
4258 TosaErrorValidator.evValidateErrorIfs(
4259 self.ser,
4260 validator_fcns,
4261 error_name,
4262 op=op,
4263 input_shape = a.shape,
4264 output_shape = result_tens.shape,
4265 input_dtype = a.dtype,
4266 output_dtype = result_tens.dtype,
4267 result_tensor = result_tens,
4268 input_list=input_list,
4269 output_list=output_list,
4270 num_operands=num_operands,
4271 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004272
4273 attr = ts.TosaSerializerAttribute()
4274 attr.ReshapeAttribute(newShape)
4275
Matthew Haddone807aae2021-10-11 18:12:58 +01004276 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004277 return result_tens
4278
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004279 def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
4280 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4281
4282 # Invalidate Input/Output list for error if checks.
4283 input_list = [a.name]
4284 output_list = [result_tens.name]
4285 pCount, cCount = op["operands"]
4286 num_operands = pCount + cCount
4287 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4288
4289 TosaErrorValidator.evValidateErrorIfs(
4290 self.ser,
4291 validator_fcns,
4292 error_name,
4293 op=op,
4294 axis=axis,
4295 input_shape = a.shape,
4296 output_shape = result_tens.shape,
4297 input_dtype = a.dtype,
4298 output_dtype = result_tens.dtype,
4299 result_tensor = result_tens,
4300 input_list=input_list,
4301 output_list=output_list,
4302 num_operands=num_operands,
4303 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004304
4305 attr = ts.TosaSerializerAttribute()
4306 attr.AxisAttribute(axis)
4307
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004308 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004309 return result_tens
4310
Matthew Haddone807aae2021-10-11 18:12:58 +01004311 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
4312 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004313
Kevin Chengfe392ce2021-10-18 21:51:55 +00004314 attr = ts.TosaSerializerAttribute()
4315 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07004316
Matthew Haddone807aae2021-10-11 18:12:58 +01004317 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00004318 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01004319 output_list = [result_tens.name]
4320 pCount, cCount = op["operands"]
4321 num_operands = pCount + cCount
4322 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4323
4324 TosaErrorValidator.evValidateErrorIfs(
4325 self.ser,
4326 validator_fcns,
4327 error_name,
4328 op=op,
4329 input_shape = a.shape,
4330 output_shape = result_tens.shape,
4331 perms=perms,
4332 input_dtype = a.dtype,
4333 output_dtype = result_tens.dtype,
4334 result_tensor = result_tens,
4335 input_list=input_list,
4336 output_list=output_list,
4337 num_operands=num_operands,
4338 )
4339
4340
Kevin Chengfe392ce2021-10-18 21:51:55 +00004341 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004342 return result_tens
4343
Matthew Haddone807aae2021-10-11 18:12:58 +01004344 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
4345 result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)
4346
4347 # Invalidate Input/Output list for error if checks.
4348 input_list = [a.name]
4349 output_list = [result_tens.name]
4350 pCount, cCount = op["operands"]
4351 num_operands = pCount + cCount
4352 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4353
4354 TosaErrorValidator.evValidateErrorIfs(
4355 self.ser,
4356 validator_fcns,
4357 error_name,
4358 op=op,
4359 input_shape = a.shape,
4360 output_shape = result_tens.shape,
4361 input_dtype = a.dtype,
4362 output_dtype = result_tens.dtype,
4363 start=start,
4364 size=size,
4365 result_tensor = result_tens,
4366 input_list=input_list,
4367 output_list=output_list,
4368 num_operands=num_operands,
4369 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004370
4371 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01004372 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07004373
Matthew Haddone807aae2021-10-11 18:12:58 +01004374 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004375 return result_tens
4376
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004377 def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
4378 result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
4379
4380 # Invalidate Input/Output list for error if checks.
4381 input_list = [a.name]
4382 output_list = [result_tens.name]
4383 pCount, cCount = op["operands"]
4384 num_operands = pCount + cCount
4385 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4386
4387 TosaErrorValidator.evValidateErrorIfs(
4388 self.ser,
4389 validator_fcns,
4390 error_name,
4391 op=op,
4392 input_shape = a.shape,
4393 output_shape = result_tens.shape,
4394 input_dtype = a.dtype,
4395 output_dtype = result_tens.dtype,
4396 result_tensor = result_tens,
4397 input_list=input_list,
4398 output_list=output_list,
4399 num_operands=num_operands,
4400 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004401
4402 attr = ts.TosaSerializerAttribute()
4403 attr.TileAttribute(multiples)
4404
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004405 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004406 return result_tens
4407
    def build_gather(self, op, values, validator_fcns=None, error_name=None):
        """Serialize a GATHER operator, generating a random index tensor.

        A new (N, W) int32 index constant is created whose entries are drawn
        in [0, K) with K = values.shape[1], so lookups stay within the values
        tensor; W is drawn from the configured tensor shape range.
        """

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values tensor

        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indicies, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [values.name, indicies.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = values.shape,
            output_shape = result_tens.shape,
            input_dtype = values.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list)

        return result_tens
4449
    def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
        """Serialize a SCATTER operator, generating a random index tensor.

        A new (N, W) int32 index constant is created whose entries are drawn
        in [0, K) with K = values_in.shape[1]; W is taken from the update
        tensor's second dimension.
        """

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, self.rng, values_in, indicies, input, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [values_in.name, indicies.name, input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = values_in.shape,
            output_shape = result_tens.shape,
            input_dtype = values_in.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list)

        return result_tens
4489
Matthew Haddon848efb42021-09-09 12:30:53 +01004490
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        validator_fcns,
        error_name = None,
    ):
        """Serialize a RESIZE operator.

        Carries both the integer (stride/offset/shift) and floating-point
        (stride_fp/offset_fp) parameterizations through to the attribute and
        to the ERROR_IF validator so invalid combinations can be checked.
        """
        result_tens = OutputShaper.resizeOp(
            self.ser,
            self.rng,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
            error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            mode=mode,
            shift=shift,
            input_dtype=input_dtype,
            output_dtype=output_dtype,
            input_shape=input.shape,
            output_shape=output_dims,
            offset=offset,
            offset_fp=offset_fp,
            stride=stride,
            stride_fp=stride_fp,
            input_list=input_list,
            output_list=output_list,
            result_tensor=result_tens,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4559
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004560 def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
4561 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
4562 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004563 self.ser.addOperator(
4564 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
4565 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004566 return result_tens
4567
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004568 def build_const(self, op, val, validator_fcns=None, error_name=None):
Kevin Cheng17e92022021-10-01 14:33:33 -07004569 self.ser.addOutputTensor(val)
4570 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07004571
4572 # Type Conversion
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004573 def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
4574 result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
4575
4576 # Invalidate Input/Output list for error if checks.
4577 input_list = [val.name]
4578 output_list = [result_tens.name]
4579 pCount, cCount = op["operands"]
4580 num_operands = pCount + cCount
4581 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4582
4583 TosaErrorValidator.evValidateErrorIfs(
4584 self.ser,
4585 validator_fcns,
4586 error_name,
4587 op=op,
4588 input_shape = val.shape,
4589 output_shape = result_tens.shape,
4590 input_dtype = val.dtype,
4591 output_dtype = result_tens.dtype,
4592 result_tensor = result_tens,
4593 input_list=input_list,
4594 output_list=output_list,
4595 num_operands=num_operands,
4596 )
4597
4598 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004599 return result_tens
4600
    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
        """Serialize a RESCALE from val.dtype to out_dtype.

        Generates random zero points and per-channel (or single-channel)
        scales, converts each scale into a fixed-point multiplier/shift pair,
        and emits the RESCALE attribute. The InputZeroPointNotZero /
        OutputZeroPointNotZero ERROR_IF cases force a nonzero zero point on
        dtypes that would otherwise use zero.
        """
        result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)

        # One scale channel per output channel when per_channel, else a single scale.
        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        in_type_width = self.typeWidth(val.dtype)
        out_type_width = self.typeWidth(out_dtype)

        # Input zero point: random for (U)INT8; forced nonzero for the
        # InputZeroPointNotZero ERROR_IF case; zero otherwise. The zero-point'd
        # branches also widen the type by one bit for the scale calculation
        # below (presumably to cover the zp-shifted range -- TODO confirm).
        if val.dtype == DType.INT8:
            input_zp = self.randInt(-128, 128)
            in_type_width = in_type_width + 1
        elif val.dtype == DType.UINT8:
            input_zp = self.randInt(0, 256)
            in_type_width = in_type_width + 1
        elif error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = self.randInt(-128, 128)
            if input_zp == 0:
                input_zp = input_zp + self.rng.integers(1, 10)
            in_type_width = in_type_width + 1
        else:
            input_zp = 0

        # Output zero point: same scheme as the input zero point.
        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 128)
            out_type_width = out_type_width + 1
        elif out_dtype == DType.UINT8:
            output_zp = self.randInt(0, 256)
            out_type_width = out_type_width + 1
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            output_zp = self.randInt(-128, 128)
            if output_zp == 0:
                output_zp = output_zp + self.rng.integers(1, 10)
            out_type_width = out_type_width + 1
        else:
            output_zp = 0

        # Calculate scale based on:
        # scale = a *(2^output_width)/(2^input_width))

        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

        if scale32:
            pass
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

        # Convert each float scale to an integer multiplier/shift pair.
        multiplier_arr = np.int32(np.zeros(shape=[nc]))
        shift_arr = np.int32(np.zeros(shape=[nc]))

        for i in range(nc):
            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                scale_arr[i], scale32
            )

        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))

        # Invalidate Input/Output list for error if checks.
        input_list = [val.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # RESCALE carries its zero points as a plain (input_zp, output_zp) pair.
        qinfo = (input_zp, output_zp)
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=val.dtype,
            output_dtype=out_dtype,
            input_shape=val.shape,
            qinfo=qinfo,
            scale32 = scale32,
            double_round = double_round,
            input_list=input_list,
            output_list=output_list,
            result_tensor=result_tens,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            multiplier_arr,
            shift_arr,
            scale32,
            double_round,
            per_channel,
        )

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4704
    def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
        """Build a COND_IF operator whose then/else blocks each return a single CONST tensor.

        Args:
            op: operator descriptor dict; op['op'] is the serializer opcode.
            then_tens, else_tens: template tensors - only then_tens.shape is used
                to size the generated constants (else_tens is otherwise ignored).
            cond: condition value serialized as a BOOL const tensor.
            validator_fcns: ERROR_IF validator functions run over the built graph.
            error_name: when set to one of the CondIfOutputList*GraphMismatch
                errors, the corresponding block's const gets a deliberately
                mismatched shape for negative testing.

        Returns:
            The INT32 result tensor registered as the COND_IF output.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor (rank-0 BOOL const)
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape

        # Create an incorrect output shape for error_if tests
        if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
            incorrect_shape = deepcopy(then_tens.shape)
            for i in range(len(incorrect_shape)):
                # Shift each dim by +/-2 or +/-3 so the block output cannot
                # match the COND_IF output list shape.
                incorrect_shape[i] = incorrect_shape[i] + self.rng.choice([-3, -2, 2, 3])
            incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

        # Random INT32 payloads for the two branches
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
            # Negative test: THEN block output shape deliberately wrong
            then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
            # Negative test: ELSE block output shape deliberately wrong
            else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        # Run the ERROR_IF validators over the basic blocks built above.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            basicBlocks=self.ser.basicBlocks
        )

        return result_tens
4762
Matthew Haddon630c17c2021-10-14 15:05:41 +01004763 def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004764 # For cond_if with a binary op in the then/else blocks, take a and b and
4765 # alternately add or subtract them based on the condition
4766
4767 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08004768 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07004769
Kevin Cheng550ccc52021-03-03 11:21:43 -08004770 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004771
4772 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08004773 then_block = "THEN_BLOCK"
4774 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07004775 attr = ts.TosaSerializerAttribute()
4776 attr.CondIfAttribute(then_block, else_block)
4777
Matthew Haddon630c17c2021-10-14 15:05:41 +01004778 if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
4779 ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
4780 incorrect_shape = a.shape.copy()
4781 for i in range(len(incorrect_shape)):
4782 incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
4783 incorrect_block_input = deepcopy(a)
4784 incorrect_block_input.shape = incorrect_shape
4785
4786
Eric Kunzee5e26762020-10-13 16:11:07 -07004787 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08004788 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01004789 op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08004790 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004791
Les Bell6040b4d2021-10-11 12:50:31 +01004792 if a.dtype in (DType.FLOAT, DType.INT32):
4793 then_op, else_op = Op.ADD, Op.SUB
4794 elif a.dtype in (DType.INT8, DType.INT16):
4795 then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
4796 else:
4797 assert False, f"No tests for DType: {a.dtype}"
Eric Kunzee5e26762020-10-13 16:11:07 -07004798
Les Bell6040b4d2021-10-11 12:50:31 +01004799 for block, op in ((then_block, then_op), (else_block, else_op)):
4800 self.ser.startBasicBlock(block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01004801 if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
4802 (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
4803 self.ser.addInputTensor(incorrect_block_input)
4804 self.ser.addInputTensor(b)
4805 tens = self.ser.addOutput(a.shape, a.dtype)
4806 elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
4807 (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
4808 self.ser.addInputTensor(a)
4809 self.ser.addInputTensor(b)
4810 tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
4811 else:
4812 self.ser.addInputTensor(a)
4813 self.ser.addInputTensor(b)
4814 tens = self.ser.addOutput(a.shape, a.dtype)
Les Bell6040b4d2021-10-11 12:50:31 +01004815 self.ser.addOperator(op, [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07004816
Matthew Haddon630c17c2021-10-14 15:05:41 +01004817 TosaErrorValidator.evValidateErrorIfs(
4818 self.ser,
4819 validator_fcns,
4820 error_name,
4821 op=op,
4822 a=a,
4823 b=b,
4824 basicBlocks=self.ser.basicBlocks
4825 )
4826
Eric Kunzee5e26762020-10-13 16:11:07 -07004827 return result_tens
4828
    def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
        """Build a WHILE_LOOP operator that sums tensor 'a' into an accumulator iter_val times.

        The cond block tests iter > 0; the body block does acc += a and iter -= 1.

        Args:
            op: operator descriptor dict; op['op'] is the serializer opcode.
            a: input tensor added to the accumulator each iteration.
            iter_val: initial iteration-count value (serialized as INT32 placeholder).
            validator_fcns: ERROR_IF validator functions run over the built graph.
            error_name: when set to one of the *GraphMismatch /
                CondGraphOutputNotMatchingBool errors, the matching block is
                built with deliberately invalid shapes/dtypes.

        Returns:
            The accumulator output tensor of the WHILE_LOOP.
        """
        # NOTE(review): local name 'iter' shadows the builtin - kept byte-identical here.
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        if error_name == ErrorIf.InputListOutputListMismatch:
            # Negative test: accumulator output shape deliberately mismatched
            # by shifting each dim by +/-2 or +/-3.
            incorrect_acc = deepcopy(acc)
            for i in range(len(incorrect_acc.shape)):
                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
            acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
        else:
            acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        if error_name in [ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch]:
            # Build mismatched copies of iter/acc for the negative tests below.
            incorrect_iter = deepcopy(iter)
            for i in range(len(incorrect_iter.shape)):
                incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
            if len(incorrect_iter.shape) == 0:
                # iter is rank 0, so force a mismatch by adding a dimension.
                incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

            incorrect_acc = deepcopy(acc)
            for i in range(len(incorrect_acc.shape)):
                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

        # COND block (input: iter, output: cond_tens )
        self.ser.startBasicBlock(cond_block)
        if error_name == ErrorIf.InputListCondGraphMismatch:
            # Negative test: cond graph inputs deliberately mismatched
            self.ser.addInputTensor(incorrect_iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(incorrect_acc)
        else:
            self.ser.addInputTensor(iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

        if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
            # Negative test: cond graph output must be BOOL - pick a non-BOOL dtype
            cond_tens = self.ser.addOutput([], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]))
        else:
            cond_tens = self.ser.addOutput([], DType.BOOL)

        # Loop continues while iter > 0
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        if error_name == ErrorIf.InputListBodyGraphInputMismatch:
            # Negative test: body graph inputs deliberately mismatched
            self.ser.addInputTensor(incorrect_iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(incorrect_acc)
        else:
            self.ser.addInputTensor(iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(acc)

        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

        if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
            # Negative test: body graph outputs deliberately mismatched
            iter_body_out = self.ser.addIntermediate(incorrect_iter.shape, incorrect_iter.dtype)
            acc_body_out = self.ser.addIntermediate(incorrect_acc.shape, incorrect_acc.dtype)
        else:
            iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
            acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # Body: acc += a ; iter -= 1
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        # Run the ERROR_IF validators over the basic blocks built above.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            basicBlocks=self.ser.basicBlocks
        )

        return acc_out
4929
Matthew Haddon1c00b712021-10-01 15:51:03 +01004930 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
4931 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
4932 default_test_rank_range = range(1, 5)
4933 if not shapeFilter:
4934 shapeFilter = [None]
4935
4936 # Calculate the filters based on what is requested and what the operator allows
4937 rmin, rmax = op["rank"]
4938 if rankFilter is not None:
4939 cleanRankFilter = []
4940 # Ensure rankFilter values are allowed by operator
4941 for rank in rankFilter:
4942 if rank >= rmin and rank <= rmax:
4943 cleanRankFilter.append(rank)
4944 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01004945 # Ensure default behaviour is bounded by default range or by operator,
4946 # whichever is the smaller range of ranks.
4947 opRankRange = range(rmin, rmax + 1)
4948 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01004949 else:
4950 cleanRankFilter = range(rmin, rmax + 1)
4951
4952 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004953
Matthew Haddon1c00b712021-10-01 15:51:03 +01004954 if dtypeFilter is not None:
4955 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01004956 # Create list of operator dtypes filtered by requested dtypes
4957 for dtype in dtypes:
4958 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01004959 cleanDtypeFilter.append(dtype)
4960 else:
4961 cleanDtypeFilter = dtypes
4962
4963 if testType == 'positive':
4964 filterDict = {
4965 'shapeFilter': shapeFilter,
4966 'rankFilter': cleanRankFilter,
4967 'dtypeFilter': cleanDtypeFilter
4968 }
4969 return filterDict
4970 elif testType == 'negative':
Matthew Haddone807aae2021-10-11 18:12:58 +01004971 if validator is not None:
4972 validator_info = validator(check=False, op=op)
4973 else:
4974 return None
4975
Matthew Haddon1c00b712021-10-01 15:51:03 +01004976 error_arguments = validator_info['param_reqs']
4977
4978 #Set parameters as required
4979 if error_arguments['rank'] != None:
4980 rankFilter = error_arguments['rank']
4981 else:
4982 rankFilter = cleanRankFilter
4983
4984 if error_arguments['dtype'] != None:
4985 dtypeFilter = error_arguments['dtype']
4986 else:
4987 dtypeFilter = cleanDtypeFilter
4988
4989 if error_arguments['shape'] != None:
4990 shapeFilter = error_arguments['shape']
4991 else:
4992 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
4993
4994 filterDict = {
4995 'shapeFilter': shapeFilter,
4996 'rankFilter': rankFilter,
4997 'dtypeFilter': dtypeFilter
4998 }
4999 return filterDict
5000
5001
Kevin Cheng550ccc52021-03-03 11:21:43 -08005002 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01005003 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08005004 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005005
5006 try:
5007 op = self.TOSA_OP_LIST[opName]
5008 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005009 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005010
5011 # Initialize a new random number generator
5012 self.rng = np.random.default_rng(self.random_seed)
5013
Kevin Cheng550ccc52021-03-03 11:21:43 -08005014 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005015
Eric Kunzee5e26762020-10-13 16:11:07 -07005016 # Test list consists of a tuple of:
5017 # (opName, testNameStr, dtype, shapeList, argumentsList)
5018 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01005019 if testType == 'negative' and "error_if_validators" in op:
5020 error_if_validators = op["error_if_validators"]
5021 else:
5022 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07005023
Matthew Haddon1c00b712021-10-01 15:51:03 +01005024 for validator in error_if_validators:
5025 if validator is not None:
5026 error_name = validator(check=False, op=op)['error_name']
Matthew Haddon1c00b712021-10-01 15:51:03 +01005027 else:
5028 error_name = None
5029
5030 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
Matthew Haddone807aae2021-10-11 18:12:58 +01005031 if filterDict == None:
5032 return []
Matthew Haddon1c00b712021-10-01 15:51:03 +01005033 cleanRankFilter = filterDict['rankFilter']
5034 cleanDtypeFilter = filterDict['dtypeFilter']
5035 cleanShapeFilter = filterDict['shapeFilter']
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005036 #print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005037
5038 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005039 for t in cleanDtypeFilter:
5040 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01005041 # Filter out by rank
5042 if shape is not None and len(shape) != r:
5043 continue
Matthew Haddon74567092021-07-16 15:38:20 +01005044 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005045 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005046
Matthew Haddon74567092021-07-16 15:38:20 +01005047 shapeStr = self.shapeStr(shapeList[0])
5048 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07005049
Matthew Haddon74567092021-07-16 15:38:20 +01005050 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
5051 argList = []
5052 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005053 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005054 else:
Matthew Haddon74567092021-07-16 15:38:20 +01005055 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07005056
Matthew Haddon74567092021-07-16 15:38:20 +01005057 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005058 if testType == 'positive':
5059 if argStr:
5060 testStr = "{}_{}_{}_{}".format(
5061 opName, shapeStr, typeStr, argStr
5062 )
5063 else:
5064 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
5065 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01005066 if argStr:
5067 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
5068 opName, error_name, shapeStr, typeStr, argStr
5069 )
5070 else:
5071 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005072
5073 testList.append((opName, testStr, t, error_name, shapeList, args))
5074
5075 if testType == 'positive':
5076 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
5077 if "invalid_test_validators" in op:
5078 invalid_test_validators = op["invalid_test_validators"]
5079 clean_testList = []
5080 for test in testList:
5081 for validator_fcn in invalid_test_validators:
5082 remove_test = False
5083 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
5084 remove_test = True
5085 if not remove_test:
5086 clean_testList.append(test)
5087 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07005088
5089 return testList
5090
Matthew Haddone86fd342021-09-07 16:12:21 +01005091
5092 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07005093 try:
5094 op = self.TOSA_OP_LIST[opName]
5095 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005096 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005097
5098 # Create a serializer
5099 self.createSerializer(opName, testStr)
5100
Kevin Cheng550ccc52021-03-03 11:21:43 -08005101 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01005102 if "error_if_validators" in op:
5103 error_if_validators = op["error_if_validators"]
5104 else:
5105 error_if_validators = None
5106
Kevin Cheng550ccc52021-03-03 11:21:43 -08005107 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07005108 num_operands = pCount + cCount
5109
5110 if isinstance(dtype_or_dtypeList, list):
5111 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07005112 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005113 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07005114 else:
5115 dtypeList = [dtype_or_dtypeList] * (num_operands)
5116
Kevin Cheng93a16282021-08-31 16:14:03 -07005117 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005118 assert (
5119 len(shapeList) == num_operands
5120 ), "shapeList length {} must match number of operands {}".format(
5121 len(shapeList), num_operands
5122 )
5123 assert (
5124 len(dtypeList) == num_operands
5125 ), "dtypeList length {} must match number of operands {}".format(
5126 len(dtypeList), num_operands
5127 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005128
5129 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005130 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005131 except KeyError:
5132 qgen = None
5133
5134 # Build the random tensor operands and the test
5135 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08005136
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005137 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005138
5139 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005140 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005141 else:
5142 qinfo = None
5143
5144 try:
5145 if error_if_validators is None:
5146 if qinfo is not None:
5147 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
5148 else:
5149 resultName = build_fcn(self, op, *tens, *testArgs)
5150 else:
5151 if qinfo is not None:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005152 resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, qinfo=qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005153 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005154 resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005155 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00005156 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005157 raise e
5158
5159 if resultName is None:
5160 print("Invalid ERROR_IF tests created")
5161
5162 # Save the serialized test
5163 self.serialize("test")
5164
5165
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005166 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005167 pCount, cCount = op["operands"]
5168
5169 tens = []
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005170 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005171 # Make sure the operation does not cause value saturation - where
5172 # the number wraps due to limited number of bits to store the answer
5173 assert (
5174 pCount == 2 and cCount == 0
5175 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005176 placeholders = []
5177 add = (op["op"] == Op.ADD)
5178 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5179 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5180 if add:
5181 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
5182 else:
5183 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
5184
5185 # Work out the saturation limits
5186 max_i32 = (1 << 31)-1
5187 min_i32 = -(1 << 31)
5188 max_arr = np.full(shapeList[1], max_i32)
5189 min_arr = np.full(shapeList[1], min_i32)
5190
5191 # Find how much values exceed the maximum/minimums
5192 sat_max_arr = np.maximum(res_arr - max_arr, 0)
5193 sat_min_arr = np.minimum(res_arr - min_arr, 0)
5194
5195 if not add:
5196 # Swap saturation values and negate values as we need to perform opposite operations
5197 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
5198
5199 # Create new array of unsaturated values by clipping values as needed
5200 b_unsat_arr = b_arr
5201 if (sat_max_arr != 0).any():
5202 # Clip values that cause saturation
5203 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
5204 # Reduce axes in unsaturated tensor to match original tensor
5205 for axis, dim in enumerate(b_arr.shape):
5206 if dim != b_unsat_arr.shape[axis]:
5207 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
5208 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
5209
5210 if (sat_min_arr != 0).any():
5211 # Clip values that cause saturation
5212 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
5213 # Reduce axes in unsaturated tensor to match original tensor
5214 for axis, dim in enumerate(b_arr.shape):
5215 if dim != b_unsat_arr.shape[axis]:
5216 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
5217 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
5218
5219 placeholders.append(
5220 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5221 )
5222 placeholders.append(
5223 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
5224 )
5225
5226 tens.extend(placeholders)
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005227 elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
5228 # Limit input tensors with cond_if_binary or while_loop to stop
5229 # saturation of add/sub ops
5230 pRemain = pCount
5231 placeholders = []
5232 for idx, shape in enumerate(shapeList[:]):
5233 arr = self.getRandTensor(shapeList[idx], DType.INT16)
5234 if pRemain > 0:
5235 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
5236 pRemain -= 1
5237 else:
5238 placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
5239
5240 tens.extend(placeholders)
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005241 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
5242 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005243 assert (
5244 pCount == 2 and cCount == 0
5245 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08005246
5247 placeholders = []
5248 for idx, shape in enumerate(shapeList[:]):
5249 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07005250 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005251 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005252 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005253 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005254 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005255 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005256 elif error_name == ErrorIf.WrongInputType:
5257 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005258 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005259 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08005260 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005261 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07005262 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005263
5264 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01005265 elif op["op"] == Op.SELECT:
5266 # Set datatype of condition tensor to boolean
5267 dtypeList[0] = DType.BOOL
5268 tens.extend(
5269 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5270 )
5271 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005272 elif op["op"] == Op.INTDIV and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005273 assert (
5274 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01005275 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005276
5277 placeholders = []
5278
Matthew Haddon459443c2021-08-23 16:43:13 +01005279 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005280 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07005281 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005282 while True:
5283 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5284 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5285
5286 if (divisor_arr == 0).any():
5287 continue
5288
Kevin Cheng47315e12021-05-13 17:41:28 -07005289 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005290 continue
5291
5292 break
5293
5294 placeholders.append(
5295 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
5296 )
5297 placeholders.append(
5298 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
5299 )
5300
5301 tens.extend(placeholders)
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005302 elif op["op"] == Op.MUL and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005303 assert (
5304 pCount == 2 and cCount == 0
5305 ), "Op.MUL must have 2 placeholders, 0 consts"
5306
5307 if dtypeList[0] == DType.FLOAT:
5308 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
5309 else:
5310 placeholders = []
5311
5312 # Make sure multiply result in int32 range
5313 shift = testArgs[0]
5314 if dtypeList[0] == DType.INT8:
5315 num_bits = 8
5316 elif dtypeList[0] == DType.INT16:
5317 num_bits = 16
5318 elif dtypeList[0] == DType.INT32:
5319 num_bits = 32
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005320 elif error_name == ErrorIf.WrongInputType:
5321 num_bits = 8
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005322 else:
5323 raise Exception("OpMul: invalid input dtype")
5324
5325 for idx, shape in enumerate(shapeList[:]):
5326 low = -(2 ** (num_bits - 1))
5327 high = (2 ** (num_bits - 1)) - 1
5328
5329 a_arr = np.int32(
5330 self.rng.integers(low=low, high=high, size=shapeList[0])
5331 )
5332 b_arr = np.int32(
5333 self.rng.integers(low=low, high=high, size=shapeList[1])
5334 )
5335
5336 i = 0
5337 while True:
5338
5339 a_arr_64 = a_arr.astype(np.int64)
5340 b_arr_64 = b_arr.astype(np.int64)
5341
5342 if shift > 0:
5343 rounding = 1 << (shift - 1)
5344 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
5345 else:
5346 result_arr = a_arr_64 * b_arr_64
5347
5348 if (result_arr > -(2 ** 31)).all() and (
5349 result_arr <= ((2 ** 31) - 1)
5350 ).all():
5351 break
5352
5353 i = i + 1
5354 a_arr = a_arr // 2
5355 b_arr = b_arr // 2
5356
5357 placeholders.append(
5358 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5359 )
5360 placeholders.append(
5361 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
5362 )
5363
5364 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01005365 elif op["op"] == Op.CONCAT:
5366 count = len(shapeList) - self.args.num_const_inputs_concat
5367 if count < 1:
5368 count = 1
5369 if self.args.num_const_inputs_concat == 0:
5370 count = len(shapeList)
5371
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005372 # Ensure axis is an int
5373 testArgs[0] = int(testArgs[0])
5374
5375 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name)
5376
Matthew Haddon818ab902021-07-27 09:12:49 +01005377 tens.extend(
5378 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
5379 )
5380 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005381 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07005382 tens.extend(
5383 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5384 )
5385 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07005386
Matthew Haddon1c00b712021-10-01 15:51:03 +01005387 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07005388
5389 def createDynamicOpLists(self):
5390
5391 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07005392 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005393
Kevin Cheng1533b852021-09-01 12:51:58 -07005394 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005395 testName = "conv2d_{}x{}".format(k[0], k[1])
5396 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
5397 self.TOSA_OP_LIST[testName]["filter"] = k
5398 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005399
Kevin Cheng550ccc52021-03-03 11:21:43 -08005400 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
5401 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
5402 "depthwise_conv2d_TEMPLATE"
5403 ].copy()
5404 self.TOSA_OP_LIST[testName]["filter"] = k
5405 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005406
Kevin Cheng550ccc52021-03-03 11:21:43 -08005407 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
5408 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
5409 "transpose_conv2d_TEMPLATE"
5410 ].copy()
5411 self.TOSA_OP_LIST[testName]["filter"] = k
5412 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005413
Kevin Cheng1533b852021-09-01 12:51:58 -07005414 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
5415 for k in KERNELS_3D:
5416 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
5417 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
5418 self.TOSA_OP_LIST[testName]["filter"] = k
5419 self.TOSA_OP_LIST[testName]["template"] = False
5420
Eric Kunzee5e26762020-10-13 16:11:07 -07005421 # Delete any templates after having created any dynamic ops
5422 # This is a two-pass operation because it's bad practice to delete
5423 # keys from dictionaries while iterating
5424 keyList = []
5425 for k in self.TOSA_OP_LIST:
5426 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005427 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07005428 keyList.append(k)
5429 continue
5430 except KeyError:
5431 pass
5432
5433 for k in keyList:
5434 del self.TOSA_OP_LIST[k]
5435
5436 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005437 """Fill in default fields for ops if they aren't already specified.
5438 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07005439 for op in self.TOSA_OP_LIST:
5440
5441 # Required fields
5442 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005443 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005444 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005445 raise Exception(
5446 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
5447 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005448
5449 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005450 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005451 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005452 raise Exception(
5453 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
5454 op
5455 )
5456 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005457
5458 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005459 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005460 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005461 raise Exception(
5462 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
5463 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005464
5465 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005466 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005467 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005468 raise Exception(
5469 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
5470 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005471
5472 # Put in default rank range, if missing
5473 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005474 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005475 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005476 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07005477
5478 # Tensor operator list
5479 # 'op': op name
5480 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08005481 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
5482 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07005483 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
5484 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08005485 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07005486
Kevin Cheng550ccc52021-03-03 11:21:43 -08005487 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
5488 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07005489
Kevin Cheng550ccc52021-03-03 11:21:43 -08005490 TYPE_BOOL = [DType.BOOL]
5491 TYPE_FI32 = [DType.FLOAT, DType.INT32]
5492 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
5493 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07005494
Kevin Cheng550ccc52021-03-03 11:21:43 -08005495 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07005496
Kevin Cheng1533b852021-09-01 12:51:58 -07005497 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07005498 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07005499 [DType.INT8, DType.INT8, DType.INT32],
5500 [DType.INT16, DType.INT8, DType.INT48],
5501 DType.FLOAT,
5502 ]
5503
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01005504 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07005505
5506 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08005507 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08005508 "argmax": {
5509 "op": Op.ARGMAX,
5510 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005511 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005512 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5513 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005514 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
5515 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5516 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005517 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005518 "avg_pool2d": {
5519 "op": Op.AVG_POOL2D,
5520 "operands": (1, 0),
5521 "rank": (4, 4),
5522 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5523 "qgen": TosaQuantGen.qgUnary,
5524 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00005525 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005526 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5527 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5528 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5529 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005530 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005531 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005532 "conv2d_TEMPLATE": {
5533 "op": Op.CONV2D,
5534 "operands": (1, 2),
5535 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01005536 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005537 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005538 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005539 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5540 "error_if_validators": (
5541 TosaErrorValidator.evWrongInputType,
5542 TosaErrorValidator.evWrongOutputType,
5543 TosaErrorValidator.evWrongInputList,
5544 TosaErrorValidator.evWrongOutputList,
5545 TosaErrorValidator.evInputZeroPointNotZero,
5546 TosaErrorValidator.evWeightZeroPointNotZero,
5547 TosaErrorValidator.evPadSmallerZero,
5548 TosaErrorValidator.evStrideSmallerOne,
5549 TosaErrorValidator.evDilationSmallerOne,
5550 TosaErrorValidator.evWrongRank,
5551 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005552 "template": True,
5553 },
Kevin Cheng1533b852021-09-01 12:51:58 -07005554 # Templated operator. Filled in by createDynamicOpLists
5555 "conv3d_TEMPLATE": {
5556 "op": Op.CONV3D,
5557 "operands": (1, 2),
5558 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01005559 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07005560 "qgen": TosaQuantGen.qgConv,
5561 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005562 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5563 "error_if_validators": (
5564 TosaErrorValidator.evWrongInputType,
5565 TosaErrorValidator.evWrongOutputType,
5566 TosaErrorValidator.evWrongInputList,
5567 TosaErrorValidator.evWrongOutputList,
5568 TosaErrorValidator.evInputZeroPointNotZero,
5569 TosaErrorValidator.evWeightZeroPointNotZero,
5570 TosaErrorValidator.evPadSmallerZero,
5571 TosaErrorValidator.evStrideSmallerOne,
5572 TosaErrorValidator.evDilationSmallerOne,
5573 TosaErrorValidator.evWrongRank,
5574 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07005575 "template": True,
5576 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005577 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005578 "depthwise_conv2d_TEMPLATE": {
5579 "op": Op.DEPTHWISE_CONV2D,
5580 "operands": (1, 2),
5581 "filter": [1, 1],
5582 "rank": (4, 4),
5583 "build_fcn": (
5584 build_depthwise_conv2d,
5585 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01005586 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005587 ),
5588 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005589 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005590 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5591 "error_if_validators": (
5592 TosaErrorValidator.evWrongInputType,
5593 TosaErrorValidator.evWrongOutputType,
5594 TosaErrorValidator.evWrongInputList,
5595 TosaErrorValidator.evWrongOutputList,
5596 TosaErrorValidator.evInputZeroPointNotZero,
5597 TosaErrorValidator.evWeightZeroPointNotZero,
5598 TosaErrorValidator.evPadSmallerZero,
5599 TosaErrorValidator.evStrideSmallerOne,
5600 TosaErrorValidator.evDilationSmallerOne,
5601 TosaErrorValidator.evWrongRank,
5602 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005603 "template": True,
5604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005605 "fully_connected": {
5606 "op": Op.FULLY_CONNECTED,
5607 "operands": (1, 2),
5608 "rank": (2, 2),
5609 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
5610 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005611 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005612 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
5613 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005615 "matmul": {
5616 "op": Op.MATMUL,
5617 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07005618 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08005619 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
5620 "qgen": TosaQuantGen.qgMatmul,
5621 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005622 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5623 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005624 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005625 "max_pool2d": {
5626 "op": Op.MAX_POOL2D,
5627 "operands": (1, 0),
5628 "rank": (4, 4),
5629 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5630 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00005631 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005632 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5633 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5634 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005635 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005636 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005637 "transpose_conv2d_TEMPLATE": {
5638 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07005639 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005640 "rank": (4, 4),
5641 "build_fcn": (
5642 build_transpose_conv2d,
5643 TosaTensorGen.tgTransposeConv2D,
5644 TosaArgGen.agTransposeConv2D,
5645 ),
5646 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005647 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005648 "invalid_test_validators": (
5649 TosaInvalidValidator.ivHeightWidthInvalid,
5650 TosaInvalidValidator.ivNonPositiveOutputShape,
5651 ),
5652 "error_if_validators": (
5653 TosaErrorValidator.evWrongInputType,
5654 TosaErrorValidator.evWrongOutputType,
5655 TosaErrorValidator.evWrongInputList,
5656 TosaErrorValidator.evWrongOutputList,
5657 TosaErrorValidator.evInputZeroPointNotZero,
5658 TosaErrorValidator.evWeightZeroPointNotZero,
5659 TosaErrorValidator.evPadSmallerZero,
5660 TosaErrorValidator.evStrideSmallerOne,
5661 TosaErrorValidator.evDilationSmallerOne,
5662 TosaErrorValidator.evWrongRank,
5663 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005664 "template": True,
5665 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005666 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08005667 "clamp": {
5668 "op": Op.CLAMP,
5669 "operands": (1, 0),
5670 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
5671 "types": TYPE_NARROW_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005672 "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5673 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005674 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08005675 "sigmoid": {
5676 "op": Op.SIGMOID,
5677 "operands": (1, 0),
5678 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
5679 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005680 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5681 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005682 },
5683 "tanh": {
5684 "op": Op.TANH,
5685 "operands": (1, 0),
5686 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
5687 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005688 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5689 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005690 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005691 # Elementwise Binary Operators
5692 "add": {
5693 "op": Op.ADD,
5694 "operands": (2, 0),
5695 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5696 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005697 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005698 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005699 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005700 "arithmetic_right_shift": {
5701 "op": Op.ARITHMETIC_RIGHT_SHIFT,
5702 "operands": (2, 0),
5703 "build_fcn": (
5704 build_arithmetic_right_shift,
5705 TosaTensorGen.tgBroadcastFuzz,
5706 TosaArgGen.agArithmeticRightShift,
5707 ),
5708 "types": TYPE_INT,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005709 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5710 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005711 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005712 "bitwise_and": {
5713 "op": Op.BITWISE_AND,
5714 "operands": (2, 0),
5715 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5716 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005717 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005718 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005719 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005720 "bitwise_or": {
5721 "op": Op.BITWISE_OR,
5722 "operands": (2, 0),
5723 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5724 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005725 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005726 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005727 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005728 "bitwise_xor": {
5729 "op": Op.BITWISE_XOR,
5730 "operands": (2, 0),
5731 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5732 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005733 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005734 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005735 },
Matthew Haddon459443c2021-08-23 16:43:13 +01005736 "intdiv": {
5737 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005738 "operands": (2, 0),
5739 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5740 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005741 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005742 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005743 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005744 "logical_and": {
5745 "op": Op.LOGICAL_AND,
5746 "operands": (2, 0),
5747 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5748 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005749 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005750 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005751 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005752 "logical_left_shift": {
5753 "op": Op.LOGICAL_LEFT_SHIFT,
5754 "operands": (2, 0),
5755 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5756 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005757 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005758 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005759 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005760 "logical_right_shift": {
5761 "op": Op.LOGICAL_RIGHT_SHIFT,
5762 "operands": (2, 0),
5763 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5764 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005765 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005766 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005767 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005768 "logical_or": {
5769 "op": Op.LOGICAL_OR,
5770 "operands": (2, 0),
5771 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5772 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005773 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005774 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005775 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005776 "logical_xor": {
5777 "op": Op.LOGICAL_XOR,
5778 "operands": (2, 0),
5779 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5780 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005781 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005782 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005783 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005784 "maximum": {
5785 "op": Op.MAXIMUM,
5786 "operands": (2, 0),
5787 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5788 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005789 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005790 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005792 "minimum": {
5793 "op": Op.MINIMUM,
5794 "operands": (2, 0),
5795 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5796 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005797 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005798 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005799 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005800 "mul": {
5801 "op": Op.MUL,
5802 "operands": (2, 0),
5803 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
5804 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005805 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005806 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005808 "pow": {
5809 "op": Op.POW,
5810 "operands": (2, 0),
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005811 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08005812 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005813 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005814 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005815 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005816 "sub": {
5817 "op": Op.SUB,
5818 "operands": (2, 0),
5819 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5820 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005821 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005822 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005823 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005824 "table": {
5825 "op": Op.TABLE,
5826 # Use the automatic generation functions to create the input array
5827 # but create the table tensor in the build function, as it may be
5828 # a different type from the input
5829 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00005830 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005831 "types": [DType.INT8, DType.INT16],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005832 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5833 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005834 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005835 # Elementwise Unary operators
5836 "abs": {
5837 "op": Op.ABS,
5838 "operands": (1, 0),
5839 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5840 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005841 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5842 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005844 "bitwise_not": {
5845 "op": Op.BITWISE_NOT,
5846 "operands": (1, 0),
5847 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5848 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005849 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5850 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005851 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005852 "ceil": {
5853 "op": Op.CEIL,
5854 "operands": (1, 0),
5855 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5856 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005857 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5858 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005859 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005860 "clz": {
5861 "op": Op.CLZ,
5862 "operands": (1, 0),
5863 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5864 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005865 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5866 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005867 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005868 "exp": {
5869 "op": Op.EXP,
5870 "operands": (1, 0),
5871 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5872 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005873 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5874 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005875 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005876 "floor": {
5877 "op": Op.FLOOR,
5878 "operands": (1, 0),
5879 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5880 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005881 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5882 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005883 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005884 "log": {
5885 "op": Op.LOG,
5886 "operands": (1, 0),
5887 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5888 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005889 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5890 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005891 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005892 "logical_not": {
5893 "op": Op.LOGICAL_NOT,
5894 "operands": (1, 0),
5895 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5896 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005897 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5898 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005899 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005900 "negate": {
5901 "op": Op.NEGATE,
5902 "operands": (1, 0),
5903 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5904 "qgen": TosaQuantGen.qgUnary,
5905 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005906 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5907 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5908 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005910 "reciprocal": {
5911 "op": Op.RECIPROCAL,
5912 "operands": (1, 0),
5913 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5914 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005915 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5916 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005917 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005918 "rsqrt": {
5919 "op": Op.RSQRT,
5920 "operands": (1, 0),
5921 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5922 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005923 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5924 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005926 # Elementwise Ternary operators
5927 "select": {
5928 "op": Op.SELECT,
5929 "operands": (3, 0),
5930 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
5931 "types": TYPE_FIB,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005932 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5933 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005934 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005935 # Comparison operators
5936 "equal": {
5937 "op": Op.EQUAL,
5938 "operands": (2, 0),
5939 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5940 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005941 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5942 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005944 "greater_equal": {
5945 "op": Op.GREATER_EQUAL,
5946 "operands": (2, 0),
5947 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5948 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005949 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5950 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005951 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005952 "greater": {
5953 "op": Op.GREATER,
5954 "operands": (2, 0),
5955 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5956 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005957 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5958 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005960 # Reduction operators
5961 "reduce_all": {
5962 "op": Op.REDUCE_ALL,
5963 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005964 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08005965 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5966 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005967 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5968 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5969 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005970 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005971 "reduce_any": {
5972 "op": Op.REDUCE_ANY,
5973 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005974 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08005975 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5976 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005977 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5978 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5979 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005980 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005981 "reduce_max": {
5982 "op": Op.REDUCE_MAX,
5983 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005984 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08005985 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5986 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005987 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5988 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5989 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005991 "reduce_min": {
5992 "op": Op.REDUCE_MAX,
5993 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005994 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08005995 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5996 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005997 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5998 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5999 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006000 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006001 "reduce_product": {
6002 "op": Op.REDUCE_PRODUCT,
6003 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006004 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006005 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6006 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006007 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6008 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6009 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006011 "reduce_sum": {
6012 "op": Op.REDUCE_SUM,
6013 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006014 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006015 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6016 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006017 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6018 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6019 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006020 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006021 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006022 "concat": {
6023 "op": Op.CONCAT,
6024 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01006025 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006026 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006027 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
Matthew Haddon01c359d2021-10-15 16:30:48 +01006028 TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
6029 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006030 },
6031 "pad": {
6032 "op": Op.PAD,
6033 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006034 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006035 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
6036 "qgen": TosaQuantGen.qgPad,
6037 "types": TYPE_FIB,
Jeremy Johnson27cf5432021-11-16 11:12:17 +00006038 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
Matthew Haddone807aae2021-10-11 18:12:58 +01006039 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006040 },
6041 "reshape": {
6042 "op": Op.RESHAPE,
6043 "operands": (1, 0),
6044 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
6045 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006046 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6047 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006048 },
6049 "reverse": {
6050 "op": Op.REVERSE,
6051 "operands": (1, 0),
6052 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6053 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006054 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType,
6055 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006056 },
6057 "slice": {
6058 "op": Op.SLICE,
6059 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006060 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006061 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
6062 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006063 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
6064 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
6065 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006066 },
6067 "tile": {
6068 "op": Op.TILE,
6069 "operands": (1, 0),
6070 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
6071 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006072 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6073 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006074 },
6075 "transpose": {
6076 "op": Op.TRANSPOSE,
6077 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01006078 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006079 "build_fcn": (
6080 build_transpose,
6081 TosaTensorGen.tgBasic,
6082 TosaArgGen.agTranspose,
6083 ),
6084 "types": TYPE_FIB,
Jeremy Johnson27cf5432021-11-16 11:12:17 +00006085 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice,
Matthew Haddone807aae2021-10-11 18:12:58 +01006086 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006087 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006088 # Data nodes
6089 "const": {
6090 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07006091 "operands": (0, 1),
6092 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08006093 "types": TYPE_FIB,
6094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006095 "identity": {
6096 "op": Op.IDENTITY,
6097 "operands": (1, 0),
6098 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6099 "types": TYPE_FIB,
6100 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006101 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08006102 "gather": {
6103 "op": Op.GATHER,
6104 # Only specify 'values' tensor here. 'indices' is generated in op building stage
6105 "operands": (1, 0),
6106 "rank": (3, 3),
6107 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
6108 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006109 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006110 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006111 },
6112 "scatter": {
6113 "op": Op.SCATTER,
6114 # Only specify 'values_in' tensor here.
6115 #'indices' and 'input' are generated in op building stage
6116 "operands": (2, 0),
6117 "rank": (3, 3),
6118 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
6119 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006120 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006121 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006122 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006123 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08006124 "resize": {
6125 "op": Op.RESIZE,
6126 "operands": (1, 0),
6127 "rank": (4, 4),
6128 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
6129 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01006130 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
6131 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
6132 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01006133 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01006134 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
6135 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006136 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006137 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08006138 "cast": {
6139 "op": Op.CAST,
6140 "operands": (1, 0),
6141 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
6142 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006143 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6144 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006145 },
6146 "rescale": {
6147 "op": Op.RESCALE,
6148 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01006149 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006150 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01006151 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01006152 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
6153 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6154 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006155 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006156 # Custom
6157 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08006158 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07006159 # Two varients of cond_if, one that generates one of two constant tensors (no
6160 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
6161 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006162 "cond_if_const": {
6163 "op": Op.COND_IF,
6164 "operands": (0, 2),
6165 "build_fcn": (
6166 build_cond_if_const,
6167 TosaTensorGen.tgBasic,
6168 TosaArgGen.agCondIf,
6169 ),
6170 "types": [DType.BOOL],
Matthew Haddon630c17c2021-10-14 15:05:41 +01006171 "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006172 },
6173 "cond_if_binary": {
6174 "op": Op.COND_IF,
6175 "operands": (2, 0),
6176 "build_fcn": (
6177 build_cond_if_binary,
6178 TosaTensorGen.tgBasic,
6179 TosaArgGen.agCondIf,
6180 ),
Les Bell6040b4d2021-10-11 12:50:31 +01006181 "types": TYPE_INT_FP,
Matthew Haddon630c17c2021-10-14 15:05:41 +01006182 "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch,
6183 TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006184 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006185 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08006186 "while_loop": {
6187 "op": Op.WHILE_LOOP,
6188 "operands": (0, 1),
6189 "build_fcn": (
6190 build_while_loop,
6191 TosaTensorGen.tgBasic,
6192 TosaArgGen.agWhileLoop,
6193 ),
6194 "types": [DType.INT32],
Matthew Haddon630c17c2021-10-14 15:05:41 +01006195 "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch,
6196 TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch,
6197 TosaErrorValidator.evCondGraphOutputNotMatchingBool)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006198 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006199 }
6200
Kevin Cheng550ccc52021-03-03 11:21:43 -08006201
Eric Kunzee5e26762020-10-13 16:11:07 -07006202class OutputShaper:
6203 # Methods in this class compute the expected output shape and datatype
6204 # for common classes of operations
    def __init__(self):
        # No per-instance state is kept: every visible shape-inference helper
        # in this class is a @staticmethod, so the class acts as a namespace.
        pass
6207
6208 # These methods return arguments that can be used for
6209 # creating a new output tensor
6210 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01006211 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
6212 if error_name != ErrorIf.RankMismatch:
6213 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006214 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006215
6216 shape = []
6217 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01006218 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07006219 shape.append(b.shape[i])
6220 else:
6221 shape.append(a.shape[i])
6222
Matthew Haddoneacff9a2021-09-24 14:42:13 +01006223 if error_name == ErrorIf.WrongOutputType:
6224 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6225 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6226 outputDType = rng.choice(wrong_dtypes)
6227 else:
6228 outputDType = a.dtype
6229
6230 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006231
6232 @staticmethod
6233 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006234 assert len(a.shape) == len(b.shape)
6235 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006236
6237 shape = []
6238 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006239 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07006240 shape.append(a.shape[i])
6241
Kevin Cheng550ccc52021-03-03 11:21:43 -08006242 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006243
6244 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01006245 def unaryOp(ser, rng, a, error_name=None):
6246 if error_name == ErrorIf.WrongOutputType:
6247 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6248 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6249 outputDType = rng.choice(wrong_dtypes)
6250 else:
6251 outputDType = a.dtype
6252
6253 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006254
6255 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006256 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006257 if error_name != ErrorIf.RankMismatch:
6258 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006259 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006260
6261 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006262 for i in range(len(cond.shape)):
6263 if cond.shape[i] == 1 and error_name == None:
6264 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
6265 else:
6266 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07006267
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006268 if error_name == ErrorIf.WrongOutputType:
6269 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6270 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6271 outputDType = rng.choice(wrong_dtypes)
6272 else:
6273 outputDType = a.dtype
6274
6275 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006276
6277 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006278 def binaryComparisonOp(ser, rng, a, b , error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006279 if error_name != ErrorIf.RankMismatch:
6280 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006281 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006282
6283 # Do broadcast
6284 shape = []
6285 for i in range(len(a.shape)):
6286 if a.shape[i] == 1:
6287 shape.append(b.shape[i])
6288 else:
6289 shape.append(a.shape[i])
6290
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006291 if error_name == ErrorIf.WrongOutputType:
6292 wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6293 outputDType = rng.choice(wrong_dtypes)
6294 else:
6295 outputDType = DType.BOOL
6296
6297 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006298
6299 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01006300 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006301 shape = a.shape.copy()
Matthew Haddond6ce7252021-09-29 15:35:44 +01006302 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
6303 shape[axis] = 1
6304 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
6305 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07006306
Matthew Haddond6ce7252021-09-29 15:35:44 +01006307 if error_name == ErrorIf.WrongOutputType:
6308 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6309 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6310 outputDType = rng.choice(wrong_dtypes)
6311 else:
6312 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006313
Matthew Haddond6ce7252021-09-29 15:35:44 +01006314 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006315
6316 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006317 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006318 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006319
6320 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
6321 del shape[axis]
6322
6323 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
6324 remove = rng.choice([True, False])
6325 if remove and len(shape) > 1:
6326 del shape[0]
6327 else:
6328 shape.append(1)
6329 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
6330 for i in range(len(shape)):
6331 shape[i] = shape[i] + rng.integers(1, 10)
6332
6333 if error_name == ErrorIf.WrongOutputType:
6334 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6335 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
6336 outputDType = rng.choice(wrong_dtypes)
6337 else:
6338 outputDType = DType.INT32
6339
6340 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006341
6342 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00006343 def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006344
6345 # IFM: NHWC
6346 # Filter: OHWI
6347 # OFM: NHWC
6348
6349 if len(padding) == 2:
6350 # Expand padding to 4 parameters in the case of transpose_conv2d
6351 # From H,W to T,B,L,R
6352 padding = [padding[0], padding[0], padding[1], padding[1]]
6353
Kevin Cheng550ccc52021-03-03 11:21:43 -08006354 h = (
6355 ifm.shape[1]
6356 - filter.shape[1]
6357 - (filter.shape[1] - 1) * (dilations[0] - 1)
6358 + padding[0]
6359 + padding[1]
6360 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07006361
Kevin Cheng550ccc52021-03-03 11:21:43 -08006362 w = (
6363 ifm.shape[2]
6364 - filter.shape[2]
6365 - (filter.shape[2] - 1) * (dilations[1] - 1)
6366 + padding[2]
6367 + padding[3]
6368 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07006369
Les Bell0e027d42021-11-09 14:42:14 +00006370 # Avoid illegal dimensions, which can be generated in error_if tests
6371 h = max(h, 1)
6372 w = max(w, 1)
6373
Eric Kunzee5e26762020-10-13 16:11:07 -07006374 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
6375
Kevin Cheng3a478572021-01-22 17:21:02 -08006376 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006377 out_dtype = DType.INT32
6378 elif ifm.dtype == DType.INT16:
6379 out_dtype = DType.INT48
6380 elif ifm.dtype == DType.FLOAT:
6381 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00006382 elif error_name == ErrorIf.WrongInputType:
6383 # Pick some potentially correct output dtype if input type is incorrect
6384 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006385 else:
Les Bell0e027d42021-11-09 14:42:14 +00006386 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
6387
6388 if error_name == ErrorIf.WrongOutputType:
6389 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
6390 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07006391
Kevin Cheng550ccc52021-03-03 11:21:43 -08006392 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006393
6394 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00006395 def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Kevin Cheng1533b852021-09-01 12:51:58 -07006396
6397 # IFM: NDHWC
6398 # Filter: ODHWI
6399 # OFM: NDHWC
6400
6401 d = (
6402 ifm.shape[1]
6403 - filter.shape[1]
6404 - (filter.shape[1] - 1) * (dilations[0] - 1)
6405 + padding[0]
6406 + padding[1]
6407 ) // strides[0] + 1
6408
6409 h = (
6410 ifm.shape[2]
6411 - filter.shape[2]
6412 - (filter.shape[2] - 1) * (dilations[1] - 1)
6413 + padding[2]
6414 + padding[3]
6415 ) // strides[1] + 1
6416
6417 w = (
6418 ifm.shape[3]
6419 - filter.shape[3]
6420 - (filter.shape[3] - 1) * (dilations[2] - 1)
6421 + padding[4]
6422 + padding[5]
6423 ) // strides[2] + 1
6424
Les Bell0e027d42021-11-09 14:42:14 +00006425 # Avoid illegal dimensions, which can be generated in error_if tests
6426 d = max(d, 1)
6427 h = max(h, 1)
6428 w = max(w, 1)
6429
Kevin Cheng1533b852021-09-01 12:51:58 -07006430 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
6431
6432 if ifm.dtype == DType.INT8:
6433 out_dtype = DType.INT32
6434 elif ifm.dtype == DType.INT16:
6435 out_dtype = DType.INT48
6436 elif ifm.dtype == DType.FLOAT:
6437 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00006438 elif error_name == ErrorIf.WrongInputType:
6439 # Pick some potentially correct output dtype if input type is incorrect
6440 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07006441 else:
Les Bell0e027d42021-11-09 14:42:14 +00006442 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
6443
6444 if error_name == ErrorIf.WrongOutputType:
6445 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
6446 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07006447
6448 return ser.addOutput(ofm_shape, out_dtype)
6449
6450 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00006451 def depthwiseConv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006452 # IFM: NHWC
6453 # Filter: HWCM
6454 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08006455 h = (
6456 ifm.shape[1]
6457 - filter.shape[0]
6458 - (filter.shape[0] - 1) * (dilations[0] - 1)
6459 + padding[0]
6460 + padding[1]
6461 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07006462
Kevin Cheng550ccc52021-03-03 11:21:43 -08006463 w = (
6464 ifm.shape[2]
6465 - filter.shape[1]
6466 - (filter.shape[1] - 1) * (dilations[1] - 1)
6467 + padding[2]
6468 + padding[3]
6469 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07006470
Les Bell0e027d42021-11-09 14:42:14 +00006471 # Avoid illegal dimensions, which can be generated in error_if tests
6472 h = max(h, 1)
6473 w = max(w, 1)
6474
Eric Kunzee5e26762020-10-13 16:11:07 -07006475 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
6476
Kevin Cheng3a478572021-01-22 17:21:02 -08006477 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006478 out_dtype = DType.INT32
6479 elif ifm.dtype == DType.INT16:
6480 out_dtype = DType.INT48
6481 elif ifm.dtype == DType.FLOAT:
6482 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00006483 elif error_name == ErrorIf.WrongInputType:
6484 # Pick some potentially correct output dtype if input type is incorrect
6485 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006486 else:
Les Bell0e027d42021-11-09 14:42:14 +00006487 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
6488
6489 if error_name == ErrorIf.WrongOutputType:
6490 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
6491 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07006492
Kevin Cheng550ccc52021-03-03 11:21:43 -08006493 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006494
6495 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01006496 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006497 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01006498 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006499 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01006500 h = 1
6501 w = 1
6502 else:
6503 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
6504 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
6505
6506 if error_name == ErrorIf.PoolingOutputShapeMismatch:
6507 choices = [1, 2, 3, 4, 5]
6508 h = h + rng.choice(choices)
6509 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07006510
Eric Kunzee5e26762020-10-13 16:11:07 -07006511 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01006512
6513 if error_name == ErrorIf.WrongOutputType:
6514 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6515 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
6516 outputDType = rng.choice(wrong_dtypes)
6517 else:
6518 outputDType = ifm.dtype
6519
6520 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006521
6522 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006523 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006524 # input: N, IC
6525 # filter: OC, IC
6526 # output: N, OC
6527
6528 output_shape = [input.shape[0], filter.shape[0]]
6529
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006530 if error_name == ErrorIf.WrongOutputType:
6531 if input.dtype == DType.INT8:
6532 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
6533 elif input.dtype == DType.INT16:
6534 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
6535 elif input.dtype == DType.FLOAT:
6536 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
6537 out_dtype = rng.choice(a=incorrect_types)
6538 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006539 out_dtype = DType.INT32
6540 elif input.dtype == DType.INT16:
6541 out_dtype = DType.INT48
6542 elif input.dtype == DType.FLOAT:
6543 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006544 elif error_name == ErrorIf.WrongInputType:
6545 # Pick some potentially correct output dtype if input type is incorrect
6546 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006547 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006548 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07006549
Kevin Cheng550ccc52021-03-03 11:21:43 -08006550 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006551
6552 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006553 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07006554 # a: N, H, C
6555 # b: N, C, W
6556 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07006557
Kevin Cheng2d60f002021-06-09 14:18:32 -07006558 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07006559
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006560 if error_name == ErrorIf.WrongOutputType:
6561 if a.dtype == DType.INT8:
6562 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
6563 elif a.dtype == DType.INT16:
6564 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
6565 elif a.dtype == DType.FLOAT:
6566 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
6567 out_dtype = rng.choice(a=incorrect_types)
6568 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006569 out_dtype = DType.INT32
6570 elif a.dtype == DType.INT16:
6571 out_dtype = DType.INT48
6572 elif a.dtype == DType.FLOAT:
6573 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006574 elif error_name == ErrorIf.WrongInputType:
6575 # Pick some potentially correct output dtype if input type is incorrect
6576 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006577 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006578 raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07006579
Kevin Cheng550ccc52021-03-03 11:21:43 -08006580 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006581
6582 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006583 def concatOp(ser, rng, axis, *a, error_name=None):
Matthew Haddon818ab902021-07-27 09:12:49 +01006584 input1 = a[0]
6585 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07006586
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006587 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01006588 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006589 if not (
6590 # unable to concat tensors of different ranks
6591 error_name == ErrorIf.ConcatInputRankMismatch
6592 # unable to concat tensors along an invalid axis
6593 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006594 ):
6595 for tensor in remaining_inputs:
6596 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07006597
Matthew Haddon01c359d2021-10-15 16:30:48 +01006598 if error_name == ErrorIf.ConcatShapeSumMismatch:
6599 output_shape[axis] += rng.integers(5, 10)
6600
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006601 if error_name == ErrorIf.WrongOutputType:
6602 all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
6603 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
6604 outputDType = rng.choice(wrong_dtypes)
6605 else:
6606 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01006607
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006608 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006609
6610 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01006611 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006612
6613 output_shape = a.shape.copy()
6614
6615 for i in range(len(output_shape)):
6616 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
6617
Matthew Haddone807aae2021-10-11 18:12:58 +01006618 # Fix negative output shape if error_if test causes it
6619 if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
6620 output_shape = [i if i >= 1 else 1 for i in output_shape]
6621
6622 if error_name == ErrorIf.WrongOutputType:
6623 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6624 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6625 outputDType = rng.choice(wrong_dtypes)
6626 else:
6627 outputDType = a.dtype
6628
6629 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006630
6631 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01006632 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006633 output_shape = shape.copy()
6634
6635 totalElements = 1
6636 for i in a.shape:
6637 totalElements *= i
6638
6639 # If there are any -1 elements, figure out what that dimension must be
6640 totalOutputElements = 1
6641 for i in output_shape:
6642 if i != -1:
6643 totalOutputElements *= i
6644
6645 # And fill it in
6646 for i in range(len(output_shape)):
6647 if output_shape[i] == -1:
6648 output_shape[i] = totalElements // totalOutputElements
6649
Matthew Haddone807aae2021-10-11 18:12:58 +01006650 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
6651 for i in range(len(output_shape)):
6652 output_shape[i] = output_shape[i] + rng.integers(1, 10)
6653
6654 if error_name == ErrorIf.WrongOutputType:
6655 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6656 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6657 outputDType = rng.choice(wrong_dtypes)
6658 else:
6659 outputDType = a.dtype
6660
6661 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006662
6663 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01006664 def sliceOp(ser, rng, a, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006665
Matthew Haddone807aae2021-10-11 18:12:58 +01006666 if error_name == ErrorIf.WrongOutputType:
6667 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6668 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6669 outputDType = rng.choice(wrong_dtypes)
6670 else:
6671 outputDType = a.dtype
6672
6673 if error_name == ErrorIf.SizeOutputShapeMismatch:
6674 output_shape = size.copy()
6675 for index in range(len(output_shape)):
6676 if output_shape[index] <= 2:
6677 output_shape[index] = output_shape[index] + rng.choice([1, 2])
6678 else:
6679 output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2])
6680 else:
6681 output_shape = size.copy()
6682
6683 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006684
6685 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006686 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006687
6688 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08006689 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07006690
6691 for i in range(len(output_shape)):
6692 output_shape[i] = a.shape[i] * multiples[i]
6693
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006694 if error_name == ErrorIf.WrongOutputType:
6695 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6696 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6697 outputDType = rng.choice(wrong_dtypes)
6698 else:
6699 outputDType = a.dtype
6700
6701 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006702
6703 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01006704 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006705 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01006706
Kevin Cheng550ccc52021-03-03 11:21:43 -08006707 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07006708
Matthew Haddone807aae2021-10-11 18:12:58 +01006709 if error_name == ErrorIf.IndexOutsideBounds:
6710 for i in range(len(output_shape)):
6711 output_shape[i] = a.shape[0]
6712 else:
6713 for i in range(len(output_shape)):
6714 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07006715
Matthew Haddone807aae2021-10-11 18:12:58 +01006716 if error_name == ErrorIf.WrongOutputType:
6717 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6718 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6719 outputDType = rng.choice(wrong_dtypes)
6720 else:
6721 outputDType = a.dtype
6722
6723 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006724
6725 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006726 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006727 if error_name != ErrorIf.WrongRank:
6728 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08006729 assert len(indices.shape) == 2
6730 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07006731
Kevin Cheng77d0f762020-11-24 10:26:32 -08006732 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
6733
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006734 if error_name == ErrorIf.WrongOutputType:
6735 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6736 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
6737 outputDType = rng.choice(wrong_dtypes)
6738 else:
6739 outputDType = values.dtype
6740
6741 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08006742
6743 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006744 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006745 if error_name != ErrorIf.WrongRank:
6746 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08006747 assert len(indices.shape) == 2
6748 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08006749 assert values_in.shape[0] == indices.shape[0] # N
6750 assert input.shape[1] == indices.shape[1] # W
6751 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08006752
6753 output_shape = values_in.shape
6754
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006755 if error_name == ErrorIf.WrongOutputType:
6756 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6757 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
6758 outputDType = rng.choice(wrong_dtypes)
6759 else:
6760 outputDType = values_in.dtype
6761
6762 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006763
6764 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006765 def tableOp(ser, rng, input, error_name=None):
6766 # Same shape as the input, dtype dependent on input dtype
6767 if error_name != ErrorIf.WrongInputType:
6768 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00006769 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006770 if error_name == ErrorIf.WrongOutputType:
6771 wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6772 wrong_dtypes.remove(output_dtype)
6773 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01006774 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006775
6776 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08006777 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01006778 serializer,
6779 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08006780 input,
6781 mode,
6782 stride,
6783 offset,
6784 shift,
6785 stride_fp,
6786 offset_fp,
6787 output_dims,
6788 input_dtype,
6789 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01006790 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08006791 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01006792 if error_name == ErrorIf.WrongRank:
6793 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
6794 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01006795 if error_name == ErrorIf.BatchMismatch:
6796 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
6797 elif error_name == ErrorIf.ChannelMismatch:
6798 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
6799 else:
6800 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07006801
Matthew Haddon693ba9e2021-09-22 11:24:37 +01006802 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006803
6804 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006805 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006806 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006807
6808 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00006809 def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
Kevin Cheng3a478572021-01-22 17:21:02 -08006810 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006811 out_dtype = DType.INT32
6812 elif ifm.dtype == DType.INT16:
6813 out_dtype = DType.INT48
6814 elif ifm.dtype == DType.FLOAT:
6815 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00006816 elif error_name == ErrorIf.WrongInputType:
6817 # Pick some potentially correct output dtype if input type is incorrect
6818 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006819 else:
Les Bell0e027d42021-11-09 14:42:14 +00006820 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
6821
6822 if error_name == ErrorIf.WrongOutputType:
6823 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
6824 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07006825
Kevin Cheng550ccc52021-03-03 11:21:43 -08006826 return ser.addOutput(output_shape, out_dtype)