blob: 655cdfc87100a154f73102a7a64ab7019165903e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
Eric Kunzee5e26762020-10-13 16:11:07 -070017import numpy as np
18import argparse
19import sys
20import re
21import os
22import subprocess
23import shlex
24import json
25import glob
26import math
27import queue
28import threading
29import traceback
30import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010031import itertools
Matthew Haddon630c17c2021-10-14 15:05:41 +010032from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
Les Bell0e027d42021-11-09 14:42:14 +000048from tosa.DType import DType
49from tosa.Op import Op
50from tosa.ResizeMode import ResizeMode
Eric Kunzee5e26762020-10-13 16:11:07 -070051
Matthew Haddon630c17c2021-10-14 15:05:41 +010052
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)           # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value,

    Raises:
        ValueError if the value is not found
    """
    # Scan the attribute names and return the first one whose value matches.
    matching = (name for name in dir(item) if getattr(item, name) == value)
    for name in matching:
        return name
    raise ValueError(f'value ({value}) not found')
80
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    result = set()
    for name in dir(DType):
        # Dunder names and methods are not DType values
        if name.startswith('__'):
            continue
        dtype = getattr(DType, name)
        if callable(dtype) or dtype in skip:
            continue
        result.add(dtype)
    return result
99
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
    specified by the caller, as the serializer lib does not support them.
    If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    # Always drop the serializer-unsupported types, plus any caller excludes.
    unsupported = {DType.UNKNOWN, DType.UINT8}
    if excludes:
        unsupported.update(excludes)
    return allDTypes(excludes=unsupported)
116
def product(shape):
    """Return the product of all dimensions in `shape`.

    Args:
        shape: iterable of integer dimensions

    Returns:
        The product of the dimensions; 1 for an empty shape (the
        multiplicative identity, matching the original loop behaviour).
    """
    # math.prod is the stdlib equivalent of the manual accumulate loop
    # (math is already imported at the top of this file).
    return math.prod(shape)
122
Les Bell0e027d42021-11-09 14:42:14 +0000123
class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition"""

    def __init__(self):
        # Stateless helper class - all generators are static methods.
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point for the given dtype.

        INT8 gets a random signed zero point, UINT8 a random unsigned one.
        For the zero-point ERROR_IF tests a deliberately non-zero value is
        returned; all other dtypes use a zero point of 0.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
            # Negative test: any non-zero zero point triggers the ERROR_IF check,
            # so re-map a randomly drawn 0 to 1.
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build quantization info (input_zp, output_zp) for a unary operator.

        When error_name selects a zero-point ERROR_IF test, exactly the
        corresponding zero point is made invalid (non-zero).
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build quantization info (input_zp, weights_zp) for a convolution.

        `dtype_or_dtypeList` may be a single dtype (applied to input, weights
        and accumulator alike) or a 3-element [input, weights, accumulator]
        list.  error_name selects which zero point (input or weight) is made
        invalid for negative tests.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build quantization info (a_zp, b_zp) for MATMUL.

        For the InputZeroPointNotZero negative test both operand zero points
        are made invalid, as both are inputs.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build quantization info (single input_zp) for PAD."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a fixed-point (multiplier, shift) pair.

        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
        """
        # 31-bit multiplier precision for scale32, otherwise 15-bit.
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        # frexp returns mantissa in [0.5, 1.0) and the binary exponent.
        m, shift = math.frexp(scaleFp)

        # For negative scales frexp's mantissa is negative; flip it so the
        # multiplier below is computed from a positive mantissa.
        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding may push the multiplier to exactly 2^scaleBits; halve it
        # and bump the shift to keep it in range.
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        # Convert the binary exponent into a right-shift amount.
        shift = (-shift) + scaleBits
        #print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
247
248
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        # Stateless helper class - all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Generate identical shapes for all operands of a basic operator.

        For the RankMismatch negative test, shapes after the first are
        regenerated with a different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                # Rank 1 can only grow (cannot go to rank 0); otherwise move
                # the rank up or down by one.  i != 1 keeps one operand intact.
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Generate identical NHWC (rank 4) shapes for all operands."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        # (except MaxDimExceeded, which deliberately needs an oversize dim)
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Generate [values_in, input] shapes for gather/scatter style ops.

        Both operands share batch (N) and channel (C) dims; the middle W dim
        of the second operand is chosen independently.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Generate shapes where one randomly chosen operand is broadcast.

        The broadcast operand gets one dimension set to 1 (or, for negative
        tests, an incompatible dimension or rank).
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    # Off-by-one dim: not broadcastable, not equal -> mismatch
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks)))
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Generate [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (HWCM), bias (M*C)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Generate [input, filter, bias] shapes for FULLY_CONNECTED (rank 2)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Generate [a, b] shapes for MATMUL (rank 3, matching N and C dims)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Generate a variable-length list of (mostly) identical shapes for CONCAT."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Negative test: randomly grow or shrink the rank of every
                # input after the first.
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Re-slice tgConcat's shapes along `axis` to allow const inputs.

        Splits the first shape along the concat axis so the const operands
        stay small; for ConcatInputDimMismatch it instead perturbs a
        non-concat dimension of each extra input.
        """
        if error_name in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch]:
            # These negative tests need the shapes exactly as generated.
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
619
620
Eric Kunzee5e26762020-10-13 16:11:07 -0700621class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800622 """Argument generators create exhaustive or random lists of attributes for operators that take
623 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
624 tuples where the descriptive_name is appended to the test name and the arglist is expanded
625 as arguments to the operator build function."""
626
    def __init__(self):
        # Stateless helper class - all generators are static methods.
        pass
629
630 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100631 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800632 """A trivial argument generator for operators that don't take any
633 non-tensor arguments"""
634 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700635
636 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100637 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800638 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700640 shape = shapeList[0]
641
Matthew Haddond6ce7252021-09-29 15:35:44 +0100642 if error_name == ErrorIf.AxisSmallerZero:
643 small_axis = testGen.rng.integers(-5, 0)
644 axes.append(("axis{}".format(small_axis), [small_axis]))
645 elif error_name == ErrorIf.AxisLargerRank:
646 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
647 axes.append(("axis{}".format(large_axis), [large_axis]))
648 else:
649 for a in range(0, len(shape)):
650 axes.append(("axis{}".format(a), [a]))
651
Eric Kunzee5e26762020-10-13 16:11:07 -0700652 return axes
653
    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, dilation) argument combinations for conv ops.

        Builds the cartesian product of stride/padding/dilation values, then
        sparsely samples it (heavily for positive tests, lightly for negative
        ones) and filters out combinations where the padded input would not
        exceed the kernel or dilation.  Returns a list of
        (test_name_suffix, [stride, padding, dilation]) tuples.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        # padding has two values (before/after) per spatial dimension
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 100
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        # sorted() makes iteration order (and hence the sampled subset)
        # deterministic across runs.
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list
741
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (name, [stride, pad, dilation, output_shape]) tuples for TRANSPOSE_CONV2D.

        Sweeps stride/padding/dilation combinations up to the command-line
        maxima (plus some oversize values), thins the cross-product by a
        sparsity factor, and computes a matching output shape for each
        surviving combination.  Named ERROR_IF tests substitute a single
        deliberately invalid value instead of the sweep.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        if not error_name:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 100
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        # n counts every combination; only every sparsity-th one emits a test.
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        # NOTE(review): this is the forward-convolution output
                        # formula; presumably it matches how the builder
                        # consumes the output_shape attribute - confirm there.
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list
827
828 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100829 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700830 arg_list = []
831 rank = len(shapeList[0])
832
Les Bell7ffccce2021-07-28 15:37:02 +0100833 # Exhaustively test combinations of padding on each side of each dimension
834 # - the range of padding values is defined by pad_min and pad_max
835 # - for padding >9, the name format needs to be more distinctive
836 pad_min, pad_max = 0, 1
837 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100838 if error_name == ErrorIf.PadSmallerZero:
839 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100840 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
841 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700842
Kevin Chengfe392ce2021-10-18 21:51:55 +0000843 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
844 pad_const_int = testGen.getRandNumberDType(dtype)
845 pad_const_fp = 0
846 elif dtype == DType.FLOAT:
847 pad_const_int = 0
848 pad_const_fp = testGen.getRandNumberDType(dtype)
849 else:
850 return []
851
Les Bell7ffccce2021-07-28 15:37:02 +0100852 for paddings in shape_pad_values:
853 name = "pad"
854 for r in range(rank):
855 before, after = paddings[r]
856 name = f"{name}{before}{after}"
Kevin Chengfe392ce2021-10-18 21:51:55 +0000857 arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700858
859 return arg_list
860
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (name, [stride, pad, kernel]) argument tuples for pooling ops.

        Sweeps stride/padding/kernel combinations up to the command-line
        maxima, adds a few oversize values, and thins the cross-product by
        a sparsity factor.  Pooling-specific ERROR_IF tests get invalid
        values substituted via TosaErrorIfArgGen.eiPoolingErrorIf.
        """
        arg_list = []

        shape = shapeList[0]
        # Pooling input must be NHWC (rank 4) unless deliberately testing
        # a wrong-rank error.
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        # n counts every combination; only every sparsity-th one emits a test.
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        # Negative test: have the error-if helper mutate one of
                        # stride/pad/kernel into an invalid value; None means
                        # this combination cannot express the requested error.
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
928
929 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100930 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700931 arg_list = []
932
933 # Enumerate the output types here
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100934 if error_name == ErrorIf.WrongOutputType:
935 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
936 elif inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800937 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700938 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800939 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700940 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800941 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700942 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800943 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700944 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800945 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100946 elif error_name == ErrorIf.WrongInputType:
947 # Pick some potentially correct output type for incorrect input type
948 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700949 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800950 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700951
952 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800953 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700954
955 return arg_list
956
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate (name, [out_dtype, scale32, double_round, per_channel]) tuples for RESCALE.

        Walks the cross-product of output dtype and the three boolean
        attributes, skipping combinations that are illegal for RESCALE -
        unless the named ERROR_IF test deliberately needs an illegal
        combination, in which case the matching guard is bypassed.
        """
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            # For OutputZeroPointNotZero, only output types that require a
            # zero output zero point are useful - skip int8/uint8 outputs.
            if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
                continue
            if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue
            if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
                continue

            for scale32 in [False, True]:
                # ScaleTrue / ScaleNotTrue tests pin scale32 to one value.
                if error_name == ErrorIf.ScaleTrue and scale32 == False:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and double_round == False:
                        continue
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
                            # Illegal condition. Must be scale32=False
                            continue
                        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
1004
Kevin Chengaee1fac2020-11-11 13:54:06 -08001005 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001006 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -08001007 arg_list = []
1008
1009 if dtype is DType.INT32:
1010 for p in range(testGen.args.num_rand_permutations):
1011
1012 shift = testGen.randInt(0, 32)
1013
Kevin Cheng550ccc52021-03-03 11:21:43 -08001014 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001015 else:
Matthew Haddon43e37192021-07-09 14:13:02 +01001016 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001017
1018 return arg_list
1019
1020 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001021 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -08001022 arg_list = []
1023
Kevin Cheng550ccc52021-03-03 11:21:43 -08001024 arg_list.append(("roundTrue", [True]))
1025 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001026
1027 return arg_list
1028
Eric Kunzee5e26762020-10-13 16:11:07 -07001029 # Helper function for reshape. Gets some factors of a larger number.
1030 @staticmethod
1031 def getFactors(val, start=1):
1032 factors = []
1033
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001034 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -07001035 if (val % i) == 0:
1036 factors.append(i)
1037
1038 return factors
1039
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (name, [newShape]) arguments for RESHAPE.

        Builds random target shapes with the same total element count as
        the input by multiplying together factors of that count, sometimes
        substituting a -1 (inferred dimension).  Duplicate shapes are
        rejected; a bounded retry counter stops the search from spinning.
        """
        arg_list = []

        origShape = shapeList[0]

        # Every generated shape must preserve the total element count.
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough distinct factors to build a shape of this rank.
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension absorbs whatever element count remains.
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
1095
Eric Kunzee5e26762020-10-13 16:11:07 -07001096 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001097 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001098 arg_list = []
1099
1100 ifm_shape = shapeList[0]
1101
Matthew Haddone807aae2021-10-11 18:12:58 +01001102
1103 if error_name == ErrorIf.IndexOutsideBounds:
1104 incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
1105 incorrect_small_index = range(-len(ifm_shape), 0)
1106 permutations = [p for p in itertools.permutations(incorrect_large_index)]
1107 permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
1108 elif error_name == ErrorIf.IndexUsedTwice:
1109 # Create list with a duplicated index
1110 perm_range = list(range(len(ifm_shape)))
1111 index_choice = testGen.rng.choice(range(len(perm_range)))
1112 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1113 permutations = [p for p in itertools.permutations(perm_range)]
1114
1115
1116 else:
1117 # Get all permutations
1118 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -07001119
Jeremy Johnsona6185572021-06-21 15:55:35 +01001120 # Limit to possible permutations from shape dimension or argument setting
1121 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001122
Jeremy Johnsona6185572021-06-21 15:55:35 +01001123 # Get random permutation generator that uses all permutations
1124 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001125
Jeremy Johnsona6185572021-06-21 15:55:35 +01001126 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -07001127 arg_list = [
1128 ("perm{}".format(p), [random_permutations[p].tolist()])
1129 for p in range(limit)
1130 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07001131 return arg_list
1132
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (name, [start, size]) arguments for SLICE.

        Picks a random non-empty start/size window per dimension; size-1
        dimensions always use start=0, size=1.  Candidates containing a
        zero-sized dimension are discarded.  For ERROR_IF tests the helper
        substitutes deliberately invalid start/size values.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list
1163
1164 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001165 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001166 arg_list = []
1167
1168 ifm_shape = shapeList[0]
1169 rank = len(ifm_shape)
1170
1171 for p in range(testGen.args.num_rand_permutations):
1172
1173 # Pick a few random, but small multiple values
1174 # because otherwise this has a tendency to generate
1175 # enormous tensors
1176 multiples = []
1177 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001178 if ifm_shape[i] > 1000:
1179 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
1180 multiples.append(1)
1181 elif max(ifm_shape) > 1000:
1182 multiples.append(2)
1183 else:
1184 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001185 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001186
1187 return arg_list
1188
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate RESIZE arguments.

        For each resize mode and legal output dtype, randomly chooses output
        dimensions, then derives fixed-point stride/offset (searching over
        shift values 1-11 for a combination that stays in range and never
        reads out of bounds) or floating-point stride/offset for FLOAT.
        Each test tuple is (name, [mode, stride, offset, shift, stride_fp,
        offset_fp, output_dims, dtype, outputDType]).  ERROR_IF tests have
        deliberately invalid values substituted at the end.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Keep the downscale ratio below 16 in each axis.
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Align input and output centres to derive stride/offset.
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT uses the floating-point stride/offset fields;
                        # the fixed-point fields are zeroed.
                        float_op = True
                        arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                    else:
                        float_op = False
                        arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
                        shift = testGen.randInt(1,12)
                        # Now search for a shift value (1 to 11) that will produce
                        # a valid and predictable resize operation
                        count = 0
                        while (count < 12):
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                            # Fixed-point stride/offset must stay inside the
                            # (-16, 16) range scaled by 2^shift.
                            if (
                                stride_y <= 0
                                or stride_x <= 0
                                or stride_y >= (16 << shift)
                                or stride_x >= (16 << shift)
                                or offset_y >= (16 << shift)
                                or offset_x >= (16 << shift)
                                or offset_y <= (-16 << shift)
                                or offset_x <= (-16 << shift)
                            ):
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue

                            def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift):
                                # Perform the pseudo loop to look for out of bounds
                                for pos in range(0,length_out):
                                    a = pos * stride + offset
                                    ia = a >> shift
                                    ia0 = max(ia, 0)
                                    ia1 = min(ia+1, length_in-1)
                                    if ia0 > ia1:
                                        # Found a problem value
                                        break
                                return ia0, ia1

                            iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift)
                            ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift)
                            if ix0 > ix1 or iy0 > iy1:
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue
                            break

                        if count >= 12:
                            # Couldn't find a good set of values for this test, skip it
                            continue

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                    # Common for all data types
                    if error_name is not None:
                        shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            shift,
                            stride,
                            stride_fp,
                            offset,
                            offset_fp
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_list.append(
                        (
                            arg_str.format(
                                "N" if mode == ResizeMode.NEAREST else "B",
                                shift,
                                output_dims[0],
                                output_dims[1],
                                testGen.typeStr(outputDTypeNew),
                                stride_fp[0] if float_op else stride[0],
                                stride_fp[1] if float_op else stride[1],
                                offset_fp[0] if float_op else offset[0],
                                offset_fp[1] if float_op else offset[1]
                            ),
                            [
                                mode,
                                stride,
                                offset,
                                shift,
                                stride_fp,
                                offset_fp,
                                output_dims,
                                dtype,
                                outputDTypeNew,
                            ],
                        )
                    )

        return arg_list
1351
Kevin Chengfe392ce2021-10-18 21:51:55 +00001352 @staticmethod
1353 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1354 arg_list = []
1355
1356 if dtype == DType.INT8:
1357 table = np.int32(
1358 testGen.rng.integers(low=-128, high=128, size=[256])
1359 ).tolist()
1360 else: # INT16
1361 table = np.int32(
1362 testGen.rng.integers(low=-32768, high=32768, size=[513])
1363 ).tolist()
1364
1365 arg_list.append(
1366 (
1367 "",
1368 [table],
1369 )
1370 )
1371 return arg_list
1372
Matthew Haddon1c00b712021-10-01 15:51:03 +01001373 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001374 # CondIf generates the condition values here.
1375 # Convert to tensors in the build function, along with the
1376 # then and else blocks
1377 arg_list = []
1378
1379 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001380 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001381
1382 return arg_list
1383
Matthew Haddon1c00b712021-10-01 15:51:03 +01001384 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001385 # While loop: 0 iterations, 1, more than 1
1386 arg_list = []
1387
1388 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001389 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001390
1391 return arg_list
1392
class TosaErrorIfArgGen:
    """Helpers that perturb otherwise-valid operator arguments so that a
    specific ERROR_IF condition is triggered in the reference model."""

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Adjust RESIZE parameters to provoke the requested ERROR_IF.

        Returns the (possibly modified) shift, stride, stride_fp, offset,
        offset_fp and outputDType values.
        """
        if outputDType == DType.FLOAT:
            # Float mode uses the *_fp parameters.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            # Integer (fixed-point) mode.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output dtype that is invalid for this mode/input combo.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return a (stride, pad, kernel) triple with one element made
        invalid for the requested pooling ERROR_IF, or (None, None, None)
        if error_name is not a pooling-parameter error."""
        if (error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """True if output_dtype is not a legal RESCALE output for input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        if input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor name lists for ERROR_IF checks.

        Randomly grows or shrinks the relevant list so its length no
        longer matches what the op expects.
        """
        # NOTE(review): error_name is compared against string literals here
        # (elsewhere ErrorIf members are used) — presumably the ErrorIf
        # constants are strings, so both spellings compare equal; verify.
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to max_dim and max_items."""
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        # Shrink every dimension (floor 1) until the element count fits.
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # @staticmethod added for consistency with the other ei* helpers
    # (the method takes no self/cls and is called through the class).
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) perturbed to trigger the requested SLICE ERROR_IF."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i]-1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output dtypes for input_dtype."""
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # BUG FIX: was 'assert True', which never fires and let
            # execution fall through to an unbound 'outputDType'
            # (NameError).  Fail explicitly for unsupported dtypes.
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
1562
1563
Matthew Haddone86fd342021-09-07 16:12:21 +01001564class TosaErrorValidator:
1565
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        overall_result = True
        for val_fcn in validator_fcns:
            # Run the validator in "check" mode against the op's arguments.
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result['error_name']
            error_result = val_result['error_result']
            error_reason = val_result['error_reason']

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                # The intended ERROR_IF fired: tell the serializer the test
                # must produce return code 2 (error) for this reason.
                serializer.setExpectedReturnCode(2, error_reason)
            elif error_result:  # and not expected_result
                print(f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                      f" Expected: {error_name}, Got: {validator_name}")
            elif not expected_result:  # and not error_result
                print(f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                      f" Expected: {error_name}")

            if not expected_result:
                # Dump the validator arguments to help diagnose the mismatch
                # (dtype values are printed by name rather than number).
                for k, v in sorted(kwargs.items()):
                    if k != 'op':
                        if k.endswith('dtype'):
                            v = valueToName(DType, v)
                        print(f'  {k} = {v}')

        return overall_result
1606
Matthew Haddon848efb42021-09-09 12:30:53 +01001607 @staticmethod
1608 def evWrongInputType(check=False, **kwargs):
Les Bell0e027d42021-11-09 14:42:14 +00001609 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001610
1611 # Find the unsupported input data types
Matthew Haddon848efb42021-09-09 12:30:53 +01001612 op = kwargs['op']
1613 input_dtypes = op['types']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001614 allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
Les Bell0e027d42021-11-09 14:42:14 +00001615 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
Matthew Haddon848efb42021-09-09 12:30:53 +01001616
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001617 if op['op'] == Op.CLAMP:
1618 wrong_input_dtypes.remove(DType.INT48)
1619
Matthew Haddon848efb42021-09-09 12:30:53 +01001620 if check:
1621 input_dtype = kwargs['input_dtype']
Les Bell0e027d42021-11-09 14:42:14 +00001622 if input_dtype not in allowed_input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001623 error_result = True
1624
1625 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001626 "error_name": ErrorIf.WrongInputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001627 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00001628 "error_reason": f"Input data type not supported for this operator",
1629 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
Matthew Haddon848efb42021-09-09 12:30:53 +01001630 }
1631 return info_dict
1632
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for the WrongOutputType ERROR_IF condition.

        Flags input/output dtype pairings that the operator under test
        does not permit.  Allowed pairings are hard-coded per operator
        below; any op without a dedicated branch must preserve the input
        dtype on its output.
        """
        error_result = False

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            op = kwargs['op']

            if op['op'] == Op.RESIZE:
                # Output dtype depends on both resize mode and input dtype.
                mode = kwargs['mode']
                if (
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op['op'] == Op.RESCALE:
                # Same pairings as TosaErrorIfArgGen.eiRescaleWrongOutputType.
                if input_dtype == DType.INT8:
                    if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                if input_dtype in [DType.INT16, DType.INT32]:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.INT48:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.UINT8:
                    if output_dtype != DType.INT8:
                        error_result = True

            elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator output: INT8->INT32, INT16->INT48, FLOAT->FLOAT.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op['op'] == Op.ARGMAX:
                # ARGMAX always yields INT32 indices.
                if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
                    error_result = True

            elif op['op'] == Op.MUL:
                # Integer MUL widens to INT32; float MUL stays FLOAT.
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op['op'] == Op.TABLE:
                # INT8 tables map to INT8, INT16 tables map to INT32.
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op['op'] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison ops always produce BOOL.
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op['op'] == Op.CAST:
                # CAST must change dtype, and INT48 is never a valid target.
                if (
                    (input_dtype == DType.BOOL and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
                    or (input_dtype == DType.INT8 and output_dtype not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT])
                    or (input_dtype == DType.INT16 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT])
                    or (input_dtype == DType.INT32 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT])
                    or (input_dtype == DType.FLOAT and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
                ):
                    error_result = True

            elif op['op'] in {Op.CONV2D, Op.CONV3D, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D}:
                # Convolution accumulator widening, as for FULLY_CONNECTED.
                if (
                    input_dtype == DType.INT8 and output_dtype != DType.INT32
                    or input_dtype == DType.INT16 and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default rule: output dtype must match the input dtype.
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": "Output data type not supported for this configuration of operator",
            "param_reqs": {"rank": None, "dtype": None, "shape": None}
        }
        return info_dict
1725
1726 @staticmethod
1727 def evWrongRank(check=False, **kwargs):
1728 all_ranks = (1, 2, 3, 4, 5)
1729
1730 # Make a list of incorrect ranks
1731 assert 'op' in kwargs
1732 op = kwargs['op']
1733 rmin, rmax = op['rank']
1734 rank_range = range(rmin, rmax + 1)
1735 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001736 # Remove small incorrect ranks to avoid index errors
1737 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
Matthew Haddon848efb42021-09-09 12:30:53 +01001738 # Set minimum incorrect rank to 3 to avoid index error
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001739 if op['op'] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001740 incorrect_ranks = [3, 5]
Les Bell0e027d42021-11-09 14:42:14 +00001741 elif op['op'] in [Op.TRANSPOSE]:
Matthew Haddon01c359d2021-10-15 16:30:48 +01001742 incorrect_ranks = [7, 8]
Les Bell0e027d42021-11-09 14:42:14 +00001743 elif op['op'] in [Op.CONV3D]:
1744 incorrect_ranks = [6, 7]
Matthew Haddon848efb42021-09-09 12:30:53 +01001745
1746 error_name = ErrorIf.WrongRank
1747 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1748 error_result = False
1749 error_reason = "Rank not supported for this operator"
1750
1751 if check:
1752 input_shape = kwargs['input_shape']
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001753
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001754 if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
Matthew Haddon848efb42021-09-09 12:30:53 +01001755 error_result = True
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001756 elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
1757 error_result = True
1758 elif op['op'] == Op.MATMUL and len(input_shape) != 3:
1759 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001760 else:
1761 if len(input_shape) not in rank_range:
1762 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001763
1764 info_dict = {
1765 "error_name": error_name,
1766 "error_result": error_result,
1767 "error_reason": error_reason,
1768 "param_reqs": param_reqs
1769 }
1770 return info_dict
1771
1772 @staticmethod
1773 def evWrongInputList(check=False, **kwargs):
1774 error_name = ErrorIf.WrongInputList
1775 param_reqs = {"rank": None, "dtype": None, "shape": None}
1776 error_result = False
1777 error_reason = "Op input list does not match expected input"
1778
1779 if check:
1780 op = kwargs['op']
1781 input_list = kwargs['input_list']
1782 num_operands = kwargs['num_operands']
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001783 if op['op'] in [Op.SCATTER, Op.GATHER]:
1784 # SCATTER/GATHER add an indices input tensor in their build functions
1785 num_operands += 1
Kevin Chengfe392ce2021-10-18 21:51:55 +00001786 if len(input_list) != num_operands:
1787 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001788
1789 info_dict = {
1790 "error_name": error_name,
1791 "error_result": error_result,
1792 "error_reason": error_reason,
1793 "param_reqs": param_reqs
1794 }
1795 return info_dict
1796
1797 @staticmethod
1798 def evWrongOutputList(check=False, **kwargs):
1799 error_name = ErrorIf.WrongOutputList
1800 param_reqs = {"rank": None, "dtype": None, "shape": None}
1801 error_result = False
1802 error_reason = "Op output list does not match expected output"
1803
1804 if check:
1805 output_list = kwargs['output_list']
1806 # Note this will be incorrect if an operator returns more than one output
1807 if len(output_list) != 1:
1808 error_result = True
1809
1810 info_dict = {
1811 "error_name": error_name,
1812 "error_result": error_result,
1813 "error_reason": error_reason,
1814 "param_reqs": param_reqs
1815 }
1816 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001817
1818 @staticmethod
1819 def evMaxDimExceeded(check=False, **kwargs):
1820 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001821 param_reqs = {
1822 "rank": [4,4],
1823 "dtype": [DType.INT8],
1824 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1825 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001826 error_result = False
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001827 error_reason = "At least one maximum dimension is greater than or equal to 16384"
Matthew Haddone86fd342021-09-07 16:12:21 +01001828
1829 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001830 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001831 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001832 if ((input_shape[1] >= 16384) or
1833 (input_shape[2] >= 16384) or
1834 (output_shape[0] >= 16384) or
1835 (output_shape[1] >= 16384)):
Matthew Haddone86fd342021-09-07 16:12:21 +01001836 error_result = True
1837
1838 info_dict = {
1839 "error_name": error_name,
1840 "error_result": error_result,
1841 "error_reason": error_reason,
1842 "param_reqs": param_reqs
1843 }
1844 return info_dict
1845
1846 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001847 def evBatchMismatch(check=False, **kwargs):
1848 error_name = ErrorIf.BatchMismatch
1849 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1850 error_result = False
1851 error_reason = "Input batch size not equal to output batch size"
1852
1853 assert 'op' in kwargs
1854 op = kwargs['op']
1855 rmin, rmax = op['rank']
1856 rank_range = range(rmin, rmax + 1)
1857
1858 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001859 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001860 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1861
1862 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1863 error_result = True
1864
1865 info_dict = {
1866 "error_name": error_name,
1867 "error_result": error_result,
1868 "error_reason": error_reason,
1869 "param_reqs": param_reqs
1870 }
1871 return info_dict
1872
1873 @staticmethod
1874 def evChannelMismatch(check=False, **kwargs):
1875 error_name = ErrorIf.ChannelMismatch
1876 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1877 error_result = False
1878 error_reason = "Input channel size not equal to output channel size"
1879
1880 assert 'op' in kwargs
1881 op = kwargs['op']
1882 rmin, rmax = op['rank']
1883 rank_range = range(rmin, rmax + 1)
1884
1885 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001886 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001887 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1888 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1889 error_result = True
1890
1891 info_dict = {
1892 "error_name": error_name,
1893 "error_result": error_result,
1894 "error_reason": error_reason,
1895 "param_reqs": param_reqs
1896 }
1897 return info_dict
1898
1899 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001900 def evStrideSmallerEqualZero(check=False, **kwargs):
1901 error_name = ErrorIf.StrideSmallerEqualZero
1902 param_reqs = {"rank": None, "dtype": None, "shape": None}
1903 error_result = False
1904 error_reason = "Stride value smaller than or equal zero"
1905
1906 if check:
1907 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001908 output_dtype = kwargs['output_dtype']
1909 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1910 stride = kwargs['stride'] # Work around wrong input/output type tests
1911 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001912 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001913 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1914 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001915 else:
1916 stride = kwargs['stride']
1917
1918 if min(stride) <= 0:
1919 error_result = True
1920
1921 info_dict = {
1922 "error_name": error_name,
1923 "error_result": error_result,
1924 "error_reason": error_reason,
1925 "param_reqs": param_reqs
1926 }
1927 return info_dict
1928
1929 @staticmethod
1930 def evStrideLargerEqualMax(check=False, **kwargs):
1931 error_name = ErrorIf.StrideLargerEqualMax
1932 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1933 error_result = False
1934 error_reason = "Stride value larger than or equal to maximum value"
1935
1936 if check:
1937 shift = kwargs['shift']
1938 input_dtype = kwargs['input_dtype']
1939 stride = kwargs['stride']
1940 if input_dtype in [DType.INT8, DType.INT16]:
1941 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1942 error_result = True
1943 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1944 error_result = True
1945
1946 info_dict = {
1947 "error_name": error_name,
1948 "error_result": error_result,
1949 "error_reason": error_reason,
1950 "param_reqs": param_reqs
1951 }
1952 return info_dict
1953
1954
1955 @staticmethod
1956 def evStrideLargerDimension(check=False, **kwargs):
1957 error_name = ErrorIf.StrideLargerDimension
1958 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1959 error_result = False
1960 error_reason = "Stride value larger than or equal to H/W dimension"
1961
1962 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001963 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001964 input_dtype = kwargs['input_dtype']
1965 stride = kwargs['stride_fp']
1966
1967 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1968 error_result = True
1969
1970 info_dict = {
1971 "error_name": error_name,
1972 "error_result": error_result,
1973 "error_reason": error_reason,
1974 "param_reqs": param_reqs
1975 }
1976 return info_dict
1977
1978
1979 @staticmethod
1980 def evOffsetSmallerEqualMin(check=False, **kwargs):
1981 error_name = ErrorIf.OffsetSmallerEqualMin
1982 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1983 error_result = False
1984 error_reason = "Offset value smaller than or equal to minimum value"
1985
1986 if check:
1987 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001988 output_dtype = kwargs['output_dtype']
1989 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001990 offset = kwargs['offset_fp']
1991 else:
1992 offset = kwargs['offset']
1993
1994 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1995 error_result = True
1996 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1997 error_result = True
1998
1999 info_dict = {
2000 "error_name": error_name,
2001 "error_result": error_result,
2002 "error_reason": error_reason,
2003 "param_reqs": param_reqs
2004 }
2005 return info_dict
2006
2007 @staticmethod
2008 def evOffsetLargerEqualMax(check=False, **kwargs):
2009 error_name = ErrorIf.OffsetLargerEqualMax
2010 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2011 error_result = False
2012 error_reason = "Offset value larger than or equal to maximum value"
2013
2014 if check:
2015 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01002016 output_dtype = kwargs['output_dtype']
2017 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01002018 offset = kwargs['offset_fp']
2019 else:
2020 offset = kwargs['offset']
2021
2022 if shift >= 0:
2023 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
2024 error_result = True
2025
2026 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
2027 error_result = True
2028 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
2029 error_result = True
2030
2031 info_dict = {
2032 "error_name": error_name,
2033 "error_result": error_result,
2034 "error_reason": error_reason,
2035 "param_reqs": param_reqs
2036 }
2037 return info_dict
2038
2039 @staticmethod
2040 def evShiftNotZero(check=False, **kwargs):
2041 error_name = ErrorIf.ShiftNotZero
2042 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
2043 error_result = False
2044 error_reason = "Shift value must be zero for float input"
2045
2046 if check:
2047 shift = kwargs['shift']
2048 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01002049 output_dtype = kwargs['output_dtype']
2050 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01002051 error_result = True
2052
2053 info_dict = {
2054 "error_name": error_name,
2055 "error_result": error_result,
2056 "error_reason": error_reason,
2057 "param_reqs": param_reqs
2058 }
2059 return info_dict
2060
2061
2062 @staticmethod
2063 def evShiftSmallerOne(check=False, **kwargs):
2064 error_name = ErrorIf.ShiftSmallerOne
2065 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2066 error_result = False
2067 error_reason = "Shift value smaller than one"
2068
2069 if check:
2070 shift = kwargs['shift']
2071 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01002072 output_dtype = kwargs['output_dtype']
2073 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01002074 error_result = True
2075
2076 info_dict = {
2077 "error_name": error_name,
2078 "error_result": error_result,
2079 "error_reason": error_reason,
2080 "param_reqs": param_reqs
2081 }
2082 return info_dict
2083
2084 @staticmethod
2085 def evShiftLargerEleven(check=False, **kwargs):
2086 error_name = ErrorIf.ShiftLargerEleven
2087 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2088 error_result = False
2089 error_reason = "Shift value larger than eleven"
2090
2091 if check:
2092 shift = kwargs['shift']
2093 if shift > 11:
2094 error_result = True
2095
2096 info_dict = {
2097 "error_name": error_name,
2098 "error_result": error_result,
2099 "error_reason": error_reason,
2100 "param_reqs": param_reqs
2101 }
2102 return info_dict
2103
2104
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002105 @staticmethod
2106 def evRankMismatch(check=False, **kwargs):
2107 error_name = ErrorIf.RankMismatch
2108 param_reqs = {"rank": None, "dtype": None, "shape": None}
2109 error_result = False
2110 error_reason = "Input Rank does not match output rank"
2111
2112 if check:
2113 input1_shape = kwargs['input1'].shape
2114 input2_shape = kwargs['input2'].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002115 # In case of SELECT op
2116 input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002117 output_shape = kwargs['result_tensor'].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002118 if (
2119 (len(input1_shape) != len(output_shape)) or
2120 (len(input2_shape) != len(output_shape)) or
2121 (len(input3_shape) != len(output_shape))
2122 ):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002123 error_result = True
2124
2125 info_dict = {
2126 "error_name": error_name,
2127 "error_result": error_result,
2128 "error_reason": error_reason,
2129 "param_reqs": param_reqs
2130 }
2131 return info_dict
2132
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002133 @staticmethod
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002134 def evDimensionMismatch(check=False, **kwargs):
2135 error_name = ErrorIf.DimensionMismatch
2136 param_reqs = {"rank": None, "dtype": None, "shape": None}
2137 error_result = False
2138 error_reason = "Input Dimensions do not match output"
2139
2140 if check:
2141 input1_shape = kwargs['input1'].shape
2142 input2_shape = kwargs['input2'].shape
2143 # In case of SELECT op
2144 input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
2145 output_shape = kwargs['result_tensor'].shape
2146 for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))):
2147 if (
2148 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or
2149 (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or
2150 (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
2151 ):
2152 error_result = True
2153
2154 info_dict = {
2155 "error_name": error_name,
2156 "error_result": error_result,
2157 "error_reason": error_reason,
2158 "param_reqs": param_reqs
2159 }
2160 return info_dict
2161
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """ERROR_IF check: input zero point must be 0 for non-quantizable dtypes.

        When check is False, only builds the parameter requirements used to
        drive test generation; when True, inspects kwargs and reports whether
        the error condition actually holds.
        """
        op = kwargs['op']
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype for dtype in op['types']
            if (isinstance(dtype, list) and dtype[0] not in qTypes) or
            (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs['input_dtype']
            # qinfo arrives in two forms: a plain (input_zp, output_zp) tuple,
            # or a serializer object exposing .ints.
            if isinstance(kwargs['qinfo'], tuple):
                qinfo = kwargs['qinfo']
                input_zero_point = qinfo[0]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs['qinfo'].ints
                input_zero_point = qinfo[0][1]

            if op['op'] == Op.MATMUL:
                # MATMUL has a zero point per input operand; flag if either
                # non-quantizable operand carries a non-zero zp.
                qinfo = kwargs['qinfo'].ints
                for dtype, zp in (
                    (kwargs['input_dtype'], qinfo[0][1]),
                    (kwargs['input2_dtype'], qinfo[1][1]),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None}
        }
        return info_dict
2206
2207
2208 @staticmethod
2209 def evWeightZeroPointNotZero(check=False, **kwargs):
2210 op = kwargs['op']
2211
2212 # exclude inputs with INT8 weights
2213 inputDtypes = [t for t in op['types']
2214 if not isinstance(t, list) or t[1] != DType.INT8]
2215
2216 error_name = ErrorIf.WeightZeroPointNotZero
2217 param_reqs = {
2218 "rank": None,
2219 "dtype": inputDtypes,
2220 "shape": None
2221 }
2222 error_result = False
2223 error_reason = "Weight DType not INT8 and zero point not 0"
2224
2225 if check:
2226 weight_dtype = kwargs['weight_dtype']
2227 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
2228 qinfo = kwargs['qinfo'].ints
2229 weight_zero_point = qinfo[1][1]
2230 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002231 error_result = True
2232
2233 info_dict = {
2234 "error_name": error_name,
2235 "error_result": error_result,
2236 "error_reason": error_reason,
2237 "param_reqs": param_reqs
2238 }
2239 return info_dict
2240
2241
2242 @staticmethod
2243 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002244 op = kwargs['op']
2245 inputDtypes = op['types'].copy()
2246 if DType.INT8 in inputDtypes:
2247 inputDtypes.remove(DType.INT8)
2248 if DType.UINT8 in inputDtypes:
2249 inputDtypes.remove(DType.UINT8)
2250
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002251 error_name = ErrorIf.OutputZeroPointNotZero
2252 param_reqs = {
2253 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002254 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002255 "shape": None
2256 }
2257 error_result = False
2258 error_reason = "Output DType not INT8 and zero point not 0"
2259
2260 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002261 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01002262 output_dtype = kwargs['output_dtype']
2263 if isinstance(kwargs['qinfo'], tuple):
2264 qinfo = kwargs['qinfo']
2265 output_zero_point = qinfo[1]
2266 else:
2267 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
2268 qinfo = kwargs['qinfo'].ints
2269 output_zero_point = qinfo[1][1]
2270 if op['op'] == Op.AVG_POOL2D:
2271 if input_dtype != DType.INT8 and output_zero_point != 0:
2272 error_result = True
2273 elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002274 error_result = True
2275
2276 info_dict = {
2277 "error_name": error_name,
2278 "error_result": error_result,
2279 "error_reason": error_reason,
2280 "param_reqs": param_reqs
2281 }
2282 return info_dict
2283
Matthew Haddond6ce7252021-09-29 15:35:44 +01002284 @staticmethod
2285 def evAxisSmallerZero(check=False, **kwargs):
2286 error_name = ErrorIf.AxisSmallerZero
2287 param_reqs = {"rank": None, "dtype": None, "shape": None}
2288 error_result = False
2289 error_reason = "Axis smaller than zero"
2290
2291 if check:
2292 axis = kwargs['axis']
2293 if axis < 0:
2294 error_result = True
2295
2296 info_dict = {
2297 "error_name": error_name,
2298 "error_result": error_result,
2299 "error_reason": error_reason,
2300 "param_reqs": param_reqs
2301 }
2302 return info_dict
2303
2304
2305 @staticmethod
2306 def evAxisLargerRank(check=False, **kwargs):
2307 error_name = ErrorIf.AxisLargerRank
2308 param_reqs = {"rank": None, "dtype": None, "shape": None}
2309 error_result = False
2310 error_reason = "Axis larger than rank"
2311
2312 if check:
2313 axis = kwargs['axis']
2314 shape = kwargs['input_shape']
2315 if axis > len(shape):
2316 error_result = True
2317
2318 info_dict = {
2319 "error_name": error_name,
2320 "error_result": error_result,
2321 "error_reason": error_reason,
2322 "param_reqs": param_reqs
2323 }
2324 return info_dict
2325
2326
2327 @staticmethod
2328 def evShapeOfAxisNotOne(check=False, **kwargs):
2329 error_name = ErrorIf.ShapeOfAxisNotOne
2330 param_reqs = {"rank": None, "dtype": None, "shape": None}
2331 error_result = False
2332 error_reason = "shape[axis] is not equal to 1"
2333
2334 if check:
2335 axis = kwargs['axis']
2336 shape = kwargs['output_shape']
2337 if (0 <= axis < len(shape)) and shape[axis] != 1:
2338 error_result = True
2339
2340 info_dict = {
2341 "error_name": error_name,
2342 "error_result": error_result,
2343 "error_reason": error_reason,
2344 "param_reqs": param_reqs
2345 }
2346 return info_dict
2347
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002348
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002349 @staticmethod
2350 def evPadSmallerZero(check=False, **kwargs):
2351 error_name = ErrorIf.PadSmallerZero
2352 param_reqs = {"rank": None, "dtype": None, "shape": None}
2353 error_result = False
2354 error_reason = "At least one pad is smaller than zero"
2355
2356 if check:
Matthew Haddone807aae2021-10-11 18:12:58 +01002357 op = kwargs['op']
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002358 pad = kwargs['pad']
Matthew Haddone807aae2021-10-11 18:12:58 +01002359 if op['op'] == Op.PAD:
2360 for padding in pad:
2361 if min(padding) < 0:
2362 error_result = True
2363 else:
2364 if min(pad) < 0:
2365 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002366
2367 info_dict = {
2368 "error_name": error_name,
2369 "error_result": error_result,
2370 "error_reason": error_reason,
2371 "param_reqs": param_reqs
2372 }
2373 return info_dict
2374
2375
2376 @staticmethod
2377 def evPadLargerEqualKernel(check=False, **kwargs):
2378 error_name = ErrorIf.PadLargerEqualKernel
2379 param_reqs = {"rank": None, "dtype": None, "shape": None}
2380 error_result = False
2381 error_reason = "At least one pad is larger than kernel dimension"
2382
2383 if check:
2384 pad = kwargs['pad']
2385 kernel = kwargs['kernel']
2386 if min(pad) > 0 and min(kernel) > 1:
2387 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
2388 error_result = True
2389
2390 info_dict = {
2391 "error_name": error_name,
2392 "error_result": error_result,
2393 "error_reason": error_reason,
2394 "param_reqs": param_reqs
2395 }
2396 return info_dict
2397
2398 @staticmethod
2399 def evPoolingOutputShapeMismatch(check=False, **kwargs):
2400 error_name = ErrorIf.PoolingOutputShapeMismatch
2401 param_reqs = {"rank": None, "dtype": None, "shape": None}
2402 error_result = False
2403 error_reason = "Mismatch between output shape provided and expected output shape"
2404
2405 if check:
2406 pad = kwargs['pad']
2407 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
2408
2409 kernel = kwargs['kernel']
2410 kernel_y, kernel_x = kernel[0], kernel[1]
2411
2412 input_shape = kwargs['input_shape']
2413 IH, IW = input_shape[1], input_shape[2]
2414
2415 output_shape = kwargs['output_shape']
2416 OH, OW = output_shape[1], output_shape[2]
2417
2418 stride = kwargs['stride']
2419 stride_y, stride_x = stride[0], stride[1]
2420
2421 # calculate correct height, width dimensions
2422 if stride_x != 0 and stride_y != 0:
2423 y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
2424 x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x
2425
2426 # ensure parameters are valid
2427 params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
2428 and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))
2429
2430 if params_valid and (OH != y_correct or OW != x_correct):
2431 error_result = True
2432
2433 info_dict = {
2434 "error_name": error_name,
2435 "error_result": error_result,
2436 "error_reason": error_reason,
2437 "param_reqs": param_reqs
2438 }
2439 return info_dict
2440
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002441 @staticmethod
2442 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
2443 error_name = ErrorIf.ArgmaxOutputShapeMismatch
2444 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2445 error_result = False
2446 error_reason = "Mismatch between output shape provided and expected output shape"
2447
2448 if check:
2449 output_shape = kwargs['output_shape']
2450 input_shape = kwargs['input_shape']
2451 axis = kwargs['axis']
2452
2453 dimension_match = True
2454 axis_shift = 0
2455
2456 # Check that rank is correct before trying to check dimensions
2457 if (len(input_shape) - 1) == len(output_shape):
2458 for i in range(len(input_shape)):
2459 if i == axis:
2460 axis_shift = 1
2461 continue
2462 if input_shape[i] != output_shape[i - axis_shift]:
2463 dimension_match = False
2464
2465 if not dimension_match:
2466 error_result = True
2467
2468 info_dict = {
2469 "error_name": error_name,
2470 "error_result": error_result,
2471 "error_reason": error_reason,
2472 "param_reqs": param_reqs
2473 }
2474 return info_dict
2475
2476 @staticmethod
2477 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2478 error_name = ErrorIf.ArgmaxOutputRankMismatch
2479 param_reqs = {"rank": None, "dtype": None, "shape": None}
2480 error_result = False
2481 error_reason = "Mismatch between output shape provided and expected output shape"
2482
2483 if check:
2484 output_shape = kwargs['output_shape']
2485 input_shape = kwargs['input_shape']
2486 axis = kwargs['axis']
2487 valid_params = axis >= 0 and axis < len(input_shape)
2488
2489 if valid_params and (len(input_shape) - 1) != len(output_shape):
2490 error_result = True
2491
2492 info_dict = {
2493 "error_name": error_name,
2494 "error_result": error_result,
2495 "error_reason": error_reason,
2496 "param_reqs": param_reqs
2497 }
2498 return info_dict
2499
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002500
2501 @staticmethod
2502 def evKernelSmallerOne(check=False, **kwargs):
2503 error_name = ErrorIf.KernelSmallerOne
2504 param_reqs = {"rank": None, "dtype": None, "shape": None}
2505 error_result = False
2506 error_reason = "At least one kernel dimension is smaller than zero"
2507
2508 if check:
2509 kernel = kwargs['kernel']
2510 if min(kernel) < 1:
2511 error_result = True
2512
2513 info_dict = {
2514 "error_name": error_name,
2515 "error_result": error_result,
2516 "error_reason": error_reason,
2517 "param_reqs": param_reqs
2518 }
2519 return info_dict
2520
2521 @staticmethod
2522 def evStrideSmallerOne(check=False, **kwargs):
2523 error_name = ErrorIf.StrideSmallerOne
2524 param_reqs = {"rank": None, "dtype": None, "shape": None}
2525 error_result = False
2526 error_reason = "At least one stride dimension is smaller than zero"
2527
2528 if check:
2529 stride = kwargs['stride']
2530 if min(stride) < 1:
2531 error_result = True
2532
2533 info_dict = {
2534 "error_name": error_name,
2535 "error_result": error_result,
2536 "error_reason": error_reason,
2537 "param_reqs": param_reqs
2538 }
2539 return info_dict
2540
Matthew Haddonc2025212021-10-08 21:21:05 +01002541 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00002542 def evDilationSmallerOne(check=False, **kwargs):
2543 error_result = check and min(kwargs['dilation']) < 1
2544 return {
2545 "error_name": ErrorIf.DilationSmallerOne,
2546 "error_reason": "At least one dilation is smaller than one",
2547 "param_reqs": {"rank": None, "dtype": None, "shape": None},
2548 "error_result": error_result
2549 }
2550
2551 @staticmethod
Matthew Haddonc2025212021-10-08 21:21:05 +01002552 def evScaleTrue(check=False, **kwargs):
2553 error_name = ErrorIf.ScaleTrue
2554 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2555 error_result = False
2556 error_reason = "Scale set to true but input type is INT48"
2557
2558 if check:
2559 input_dtype = kwargs['input_dtype']
2560 scale32 = kwargs['scale32']
2561 if scale32 and input_dtype == DType.INT48:
2562 error_result = True
2563
2564 info_dict = {
2565 "error_name": error_name,
2566 "error_result": error_result,
2567 "error_reason": error_reason,
2568 "param_reqs": param_reqs
2569 }
2570 return info_dict
2571
2572 @staticmethod
2573 def evScaleNotTrue(check=False, **kwargs):
2574 error_name = ErrorIf.ScaleNotTrue
2575 param_reqs = {"rank": None, "dtype": None, "shape": None}
2576 error_result = False
2577 error_reason = "Scale set to false but double round set to true"
2578
2579 if check:
2580 scale32 = kwargs['scale32']
2581 double_round = kwargs['double_round']
2582 if not scale32 and double_round:
2583 error_result = True
2584
2585 info_dict = {
2586 "error_name": error_name,
2587 "error_result": error_result,
2588 "error_reason": error_reason,
2589 "param_reqs": param_reqs
2590 }
2591 return info_dict
2592
Matthew Haddone807aae2021-10-11 18:12:58 +01002593 @staticmethod
2594 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2595 error_name = ErrorIf.TensorSizeInputOutputMismatch
2596 param_reqs = {"rank": None, "dtype": None, "shape": None}
2597 error_result = False
2598 error_reason = "Input tensor size does not match output tensor size"
2599
2600 if check:
2601 input_shape = kwargs['input_shape']
2602 output_shape = kwargs['output_shape']
2603 input_size = np.prod(input_shape)
2604 output_size = np.prod(output_shape)
2605 if input_size != output_size:
2606 error_result = True
2607
2608 info_dict = {
2609 "error_name": error_name,
2610 "error_result": error_result,
2611 "error_reason": error_reason,
2612 "param_reqs": param_reqs
2613 }
2614 return info_dict
2615
2616 @staticmethod
2617 def evStartSmallerZero(check=False, **kwargs):
2618 error_name = ErrorIf.StartSmallerZero
2619 param_reqs = {"rank": None, "dtype": None, "shape": None}
2620 error_result = False
2621 error_reason = "Starting point smaller than zero"
2622
2623 if check:
2624 input_shape = kwargs['input_shape']
2625 start = kwargs['start']
2626 rank = len(input_shape)
2627 if len(start) == rank:
2628 for index in range(rank):
2629 if start[index] < 0:
2630 error_result = True
2631
2632 info_dict = {
2633 "error_name": error_name,
2634 "error_result": error_result,
2635 "error_reason": error_reason,
2636 "param_reqs": param_reqs
2637 }
2638 return info_dict
2639
2640
2641 @staticmethod
2642 def evSizeSmallerEqualZero(check=False, **kwargs):
2643 error_name = ErrorIf.SizeSmallerEqualZero
2644 param_reqs = {"rank": None, "dtype": None, "shape": None}
2645 error_result = False
2646 error_reason = "Size smaller than or equal to zero"
2647
2648 if check:
2649 input_shape = kwargs['input_shape']
2650 size = kwargs['size']
2651 rank = len(input_shape)
2652 if len(size) == rank:
2653 for index in range(rank):
2654 if size[index] <= 0:
2655 error_result = True
2656
2657 info_dict = {
2658 "error_name": error_name,
2659 "error_result": error_result,
2660 "error_reason": error_reason,
2661 "param_reqs": param_reqs
2662 }
2663 return info_dict
2664
2665
2666 @staticmethod
2667 def evStartSizeOutsideBounds(check=False, **kwargs):
2668 error_name = ErrorIf.StartSizeOutsideBounds
2669 param_reqs = {"rank": None, "dtype": None, "shape": None}
2670 error_result = False
2671 error_reason = "starting point plus size larger than input dimension"
2672
2673 if check:
2674 input_shape = kwargs['input_shape']
2675 start = kwargs['start']
2676 size = kwargs['size']
2677 rank = len(input_shape)
2678 if len(start) == rank and len(size) == rank:
2679 for index in range(rank):
2680 if start[index] + size[index] > input_shape[index]:
2681 error_result = True
2682
2683 info_dict = {
2684 "error_name": error_name,
2685 "error_result": error_result,
2686 "error_reason": error_reason,
2687 "param_reqs": param_reqs
2688 }
2689 return info_dict
2690
2691
2692 @staticmethod
2693 def evSizeOutputShapeMismatch(check=False, **kwargs):
2694 error_name = ErrorIf.SizeOutputShapeMismatch
2695 param_reqs = {"rank": None, "dtype": None, "shape": None}
2696 error_result = False
2697 error_reason = "Size does not match output dimension"
2698
2699 if check:
2700 input_shape = kwargs['input_shape']
2701 output_shape = kwargs['output_shape']
2702 size = kwargs['size']
2703 rank = len(input_shape)
2704 if len(size) == rank:
2705 for index in range(rank):
2706 if size[index] != output_shape[index]:
2707 error_result = True
2708
2709 info_dict = {
2710 "error_name": error_name,
2711 "error_result": error_result,
2712 "error_reason": error_reason,
2713 "param_reqs": param_reqs
2714 }
2715 return info_dict
2716
2717 @staticmethod
2718 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2719 error_name = ErrorIf.InputSizeStartLengthMismatch
2720 param_reqs = {"rank": None, "dtype": None, "shape": None}
2721 error_result = False
2722 error_reason = "rank of input not equal to length of start or size"
2723
2724 if check:
2725 input_shape = kwargs['input_shape']
2726 start = kwargs['start']
2727 size = kwargs['size']
2728 rank = len(input_shape)
2729 if rank != len(start) or rank != len(size):
2730 error_result = True
2731
2732 info_dict = {
2733 "error_name": error_name,
2734 "error_result": error_result,
2735 "error_reason": error_reason,
2736 "param_reqs": param_reqs
2737 }
2738 return info_dict
2739
2740 @staticmethod
2741 def evIndexOutsideBounds(check=False, **kwargs):
2742 error_name = ErrorIf.IndexOutsideBounds
2743 param_reqs = {"rank": None, "dtype": None, "shape": None}
2744 error_result = False
2745 error_reason = "Index outside of allowed bounds"
2746
2747 if check:
2748 input_shape = kwargs['input_shape']
2749 perms = kwargs['perms']
2750 rank = len(input_shape)
2751
2752 for index in perms:
2753 if index < 0 or index > rank:
2754 error_result = True
2755
2756 info_dict = {
2757 "error_name": error_name,
2758 "error_result": error_result,
2759 "error_reason": error_reason,
2760 "param_reqs": param_reqs
2761 }
2762 return info_dict
2763
2764 @staticmethod
2765 def evIndexUsedTwice(check=False, **kwargs):
2766 error_name = ErrorIf.IndexUsedTwice
2767 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2768 error_result = False
2769 error_reason = "Index used multiple times"
2770
2771 if check:
2772 input_shape = kwargs['input_shape']
2773 perms = kwargs['perms']
2774 rank = len(input_shape)
2775
2776 unique_indices = []
2777 for index in perms:
2778 if index in unique_indices:
2779 error_result = True
2780 else:
2781 unique_indices.append(index)
2782
2783 info_dict = {
2784 "error_name": error_name,
2785 "error_result": error_result,
2786 "error_reason": error_reason,
2787 "param_reqs": param_reqs
2788 }
2789 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002790
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002791 @staticmethod
2792 def evMaxSmallerMin(check=False, **kwargs):
2793 error_name = ErrorIf.MaxSmallerMin
2794 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2795 error_result = False
2796 error_reason = "Max value smaller than min value"
2797
2798 if check:
2799 max_val = kwargs['max_val']
2800 min_val = kwargs['min_val']
2801 if max_val < min_val:
2802 error_result = True
2803
2804
2805 info_dict = {
2806 "error_name": error_name,
2807 "error_result": error_result,
2808 "error_reason": error_reason,
2809 "param_reqs": param_reqs
2810 }
2811 return info_dict
2812
2813 @staticmethod
2814 def evConcatInputRankMismatch(check=False, **kwargs):
2815 error_name = ErrorIf.ConcatInputRankMismatch
2816 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2817 error_result = False
2818 error_reason = "Input ranks are not identical"
2819
2820 if check:
2821 inputs = kwargs['inputs']
2822 input_shape = kwargs['input_shape']
2823 for input in inputs:
2824 if len(input.shape) != len(input_shape):
2825 error_result = True
2826
2827 info_dict = {
2828 "error_name": error_name,
2829 "error_result": error_result,
2830 "error_reason": error_reason,
2831 "param_reqs": param_reqs
2832 }
2833 return info_dict
2834
2835 @staticmethod
2836 def evConcatInputDimMismatch(check=False, **kwargs):
2837 error_name = ErrorIf.ConcatInputDimMismatch
2838 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2839 error_result = False
2840 error_reason = "Input dimensions differ on too many axes"
2841
2842 if check:
2843 inputs = kwargs['inputs']
2844 input_shape = kwargs['input_shape']
2845 axis = kwargs['axis']
2846
2847 # Ensure rank is valid before checking dims.
2848 valid_rank = True
2849 for input in inputs:
2850 if len(input.shape) != len(input_shape):
2851 valid_rank = False
2852
2853 if valid_rank:
2854 for input in inputs:
2855 for i, dim in enumerate(input.shape):
2856 if dim != input_shape[i] and axis != i:
2857 error_result = True
2858
2859 info_dict = {
2860 "error_name": error_name,
2861 "error_result": error_result,
2862 "error_reason": error_reason,
2863 "param_reqs": param_reqs
2864 }
2865 return info_dict
2866
Matthew Haddon630c17c2021-10-14 15:05:41 +01002867 @staticmethod
Matthew Haddon01c359d2021-10-15 16:30:48 +01002868 def evConcatShapeSumMismatch(check=False, **kwargs):
2869 error_name = ErrorIf.ConcatShapeSumMismatch
2870 param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
2871 error_result = False
2872 error_reason = "Sum of dimensions on axis not equal to output dimension"
2873
2874 if check:
2875 inputs = kwargs['inputs']
2876 input_shape = kwargs['input_shape']
2877 output_shape = kwargs['output_shape']
2878 axis = kwargs['axis']
2879
2880 # Ensure rank is valid before checking dims.
2881 valid_params = True
2882 for input in inputs:
2883 if len(input.shape) != len(input_shape):
2884 valid_params = False
2885 if axis < 0 or axis > len(input_shape):
2886 valid_params = False
2887
2888 if valid_params:
2889 axis_dim_sum = 0
2890 for input in inputs:
2891 axis_dim_sum += input.shape[axis]
2892
2893 if axis_dim_sum != output_shape[axis]:
2894 error_result = True
2895
2896
2897 info_dict = {
2898 "error_name": error_name,
2899 "error_result": error_result,
2900 "error_reason": error_reason,
2901 "param_reqs": param_reqs
2902 }
2903 return info_dict
2904
2905 @staticmethod
Matthew Haddon630c17c2021-10-14 15:05:41 +01002906 def evInputListThenGraphMismatch(check=False, **kwargs):
2907 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2908 param_reqs = {"rank": None, "dtype": None, "shape": None}
2909 error_result = False
2910 error_reason = "Input list shape does not match then-graph shape"
2911
2912 if check:
2913 a = kwargs['a']
2914 b = kwargs['b']
2915 basicBlocks = kwargs['basicBlocks']
2916 then_block = basicBlocks[1]
2917 then_inputs = then_block.inputs
2918 then_tens = then_block.tensors
2919 if (a.shape != then_tens[then_inputs[0]].shape) or (b.shape != then_tens[then_inputs[1]].shape):
2920 error_result = True
2921
2922 info_dict = {
2923 "error_name": error_name,
2924 "error_result": error_result,
2925 "error_reason": error_reason,
2926 "param_reqs": param_reqs
2927 }
2928 return info_dict
2929
2930
2931 @staticmethod
2932 def evInputListElseGraphMismatch(check=False, **kwargs):
2933 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2934 param_reqs = {"rank": None, "dtype": None, "shape": None}
2935 error_result = False
2936 error_reason = "Input list shape does not match else-graph shape"
2937
2938 if check:
2939 a = kwargs['a']
2940 b = kwargs['b']
2941 basicBlocks = kwargs['basicBlocks']
2942 else_block = basicBlocks[2]
2943 else_inputs = else_block.inputs
2944 else_tens = else_block.tensors
2945 if (a.shape != else_tens[else_inputs[0]].shape) or (b.shape != else_tens[else_inputs[1]].shape):
2946 error_result = True
2947
2948 info_dict = {
2949 "error_name": error_name,
2950 "error_result": error_result,
2951 "error_reason": error_reason,
2952 "param_reqs": param_reqs
2953 }
2954 return info_dict
2955
2956
2957 @staticmethod
2958 def evOutputListThenGraphMismatch(check=False, **kwargs):
2959 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2960 param_reqs = {"rank": None, "dtype": None, "shape": None}
2961 error_result = False
2962 error_reason = "Output list shape does not match then-graph shape"
2963
2964 if check:
2965 basicBlocks = kwargs['basicBlocks']
2966 cond_block = basicBlocks[0]
2967 cond_outputs = cond_block.outputs
2968 cond_tens = cond_block.tensors
2969 then_block = basicBlocks[1]
2970 then_outputs = then_block.outputs
2971 then_tens = then_block.tensors
2972 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2973 error_result = True
2974
2975 info_dict = {
2976 "error_name": error_name,
2977 "error_result": error_result,
2978 "error_reason": error_reason,
2979 "param_reqs": param_reqs
2980 }
2981 return info_dict
2982
2983
2984 @staticmethod
2985 def evOutputListElseGraphMismatch(check=False, **kwargs):
2986 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2987 param_reqs = {"rank": None, "dtype": None, "shape": None}
2988 error_result = False
2989 error_reason = "Output list shape does not match else-graph shape"
2990
2991 if check:
2992 basicBlocks = kwargs['basicBlocks']
2993 cond_block = basicBlocks[0]
2994 cond_outputs = cond_block.outputs
2995 cond_tens = cond_block.tensors
2996 else_block = basicBlocks[2]
2997 else_outputs = else_block.outputs
2998 else_tens = else_block.tensors
2999 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
3000 error_result = True
3001
3002 info_dict = {
3003 "error_name": error_name,
3004 "error_result": error_result,
3005 "error_reason": error_reason,
3006 "param_reqs": param_reqs
3007 }
3008 return info_dict
3009
3010
3011 @staticmethod
3012 def evInputListOutputListMismatch(check=False, **kwargs):
3013 error_name = ErrorIf.InputListOutputListMismatch
3014 param_reqs = {"rank": None, "dtype": None, "shape": None}
3015 error_result = False
3016 error_reason = "Input list does not match output list"
3017
3018 if check:
3019 basicBlocks = kwargs['basicBlocks']
3020 while_block = basicBlocks[0]
3021 while_inputs = while_block.inputs
3022 while_outputs = while_block.outputs
3023 while_tens = while_block.tensors
3024 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
3025 error_result = True
3026
3027 info_dict = {
3028 "error_name": error_name,
3029 "error_result": error_result,
3030 "error_reason": error_reason,
3031 "param_reqs": param_reqs
3032 }
3033 return info_dict
3034
3035
3036 @staticmethod
3037 def evInputListCondGraphMismatch(check=False, **kwargs):
3038 error_name = ErrorIf.InputListCondGraphMismatch
3039 param_reqs = {"rank": None, "dtype": None, "shape": None}
3040 error_result = False
3041 error_reason = "Input list does not match cond graph"
3042
3043 if check:
3044 basicBlocks = kwargs['basicBlocks']
3045 while_block = basicBlocks[0]
3046 while_inputs = while_block.inputs
3047 while_tens = while_block.tensors
3048 cond_block = basicBlocks[1]
3049 cond_inputs = cond_block.inputs
3050 cond_tens = cond_block.tensors
3051 if ((while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape) or
3052 (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape)):
3053 error_result = True
3054
3055 info_dict = {
3056 "error_name": error_name,
3057 "error_result": error_result,
3058 "error_reason": error_reason,
3059 "param_reqs": param_reqs
3060 }
3061 return info_dict
3062
3063
3064 @staticmethod
3065 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
3066 error_name = ErrorIf.InputListBodyGraphInputMismatch
3067 param_reqs = {"rank": None, "dtype": None, "shape": None}
3068 error_result = False
3069 error_reason = "Input list does not match body graph input"
3070
3071 if check:
3072 basicBlocks = kwargs['basicBlocks']
3073 while_block = basicBlocks[0]
3074 while_inputs = while_block.inputs
3075 while_tens = while_block.tensors
3076 body_block = basicBlocks[2]
3077 body_outputs = body_block.inputs
3078 body_tens = body_block.tensors
3079 if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
3080 (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
3081 error_result = True
3082
3083 info_dict = {
3084 "error_name": error_name,
3085 "error_result": error_result,
3086 "error_reason": error_reason,
3087 "param_reqs": param_reqs
3088 }
3089 return info_dict
3090
3091
3092 @staticmethod
3093 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
3094 error_name = ErrorIf.InputListBodyGraphOutputMismatch
3095 param_reqs = {"rank": None, "dtype": None, "shape": None}
3096 error_result = False
3097 error_reason = "Input list does not match body graph output"
3098
3099 if check:
3100 basicBlocks = kwargs['basicBlocks']
3101 while_block = basicBlocks[0]
3102 while_inputs = while_block.inputs
3103 while_tens = while_block.tensors
3104 body_block = basicBlocks[2]
3105 body_outputs = body_block.outputs
3106 body_tens = body_block.tensors
3107 if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
3108 (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
3109 error_result = True
3110 info_dict = {
3111 "error_name": error_name,
3112 "error_result": error_result,
3113 "error_reason": error_reason,
3114 "param_reqs": param_reqs
3115 }
3116 return info_dict
3117
3118
3119 @staticmethod
3120 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
3121 error_name = ErrorIf.CondGraphOutputNotMatchingBool
3122 param_reqs = {"rank": None, "dtype": None, "shape": None}
3123 error_result = False
3124 error_reason = "Cond graph output is not a match list of booleans"
3125
3126 if check:
3127 basicBlocks = kwargs['basicBlocks']
3128 cond_block = basicBlocks[1]
3129 cond_outputs = cond_block.outputs
3130 cond_tens = cond_block.tensors
3131 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
3132 error_result = True
3133
3134 info_dict = {
3135 "error_name": error_name,
3136 "error_result": error_result,
3137 "error_reason": error_reason,
3138 "param_reqs": param_reqs
3139 }
3140 return info_dict
3141
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003142
Matthew Haddonb724efc2021-08-25 16:40:29 +01003143class TosaInvalidValidator:
3144
3145 @staticmethod
3146 def ivWrongDataTypeOrModeResize(**kwargs):
3147 input_dtype = kwargs["input_dtype"]
3148 args = kwargs["args"]
3149 mode = args[0]
3150 stride = args[1]
3151 stride_fp = args[4]
3152 output_dtype = args[8]
3153
3154 if mode == ResizeMode.BILINEAR:
3155 # Invalid output data type / Invalid input datatype
3156 return (
3157 not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
3158 not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
3159 not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or
3160 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
3161 )
3162 elif mode == ResizeMode.NEAREST:
3163 # Invalid output data type / Invalid input datatype
3164 return (
3165 (input_dtype != output_dtype) or
3166 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
3167 )
3168 else:
3169 # Invalid resize mode
3170 return True
3171
3172 @staticmethod
3173 def ivBadStride(**kwargs):
3174 input_dtype = kwargs["input_dtype"]
3175 args = kwargs["args"]
3176 stride_x = args[1][0]
3177 stride_y = args[1][1]
3178 stride_fp_x = args[4][0]
3179 stride_fp_y = args[4][1]
3180
3181 if input_dtype == DType.FLOAT:
3182 if stride_fp_x <= 0 or stride_fp_y <= 0:
3183 # Negative or zero stride
3184 return True
3185 else:
3186 if stride_x <= 0 or stride_y <= 0:
3187 # Negative or zero stride
3188 return True
3189 return False
3190
Matthew Haddonb724efc2021-08-25 16:40:29 +01003191 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00003192 def ivHeightWidthInvalid(**kwargs):
Matthew Haddonb724efc2021-08-25 16:40:29 +01003193 opName = kwargs['opName']
3194
3195 inputShapes = kwargs['shapeList']
Les Bell0e027d42021-11-09 14:42:14 +00003196 input_shape = inputShapes[0]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003197
3198 args = kwargs['args']
3199 strides = args[0]
3200 padding = args[1]
Les Bell0e027d42021-11-09 14:42:14 +00003201
Matthew Haddonb724efc2021-08-25 16:40:29 +01003202 if opName.endswith("pool2d"):
Les Bell0e027d42021-11-09 14:42:14 +00003203 # avg_pool2d, max_pool2d
3204 kernel_shape = args[2]
3205 h = (input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]) // strides[0]
3206 w = (input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]) // strides[1]
3207 # return True if any dimension is < 1
3208 return h < 1 or w < 1
Matthew Haddonb724efc2021-08-25 16:40:29 +01003209
Les Bell0e027d42021-11-09 14:42:14 +00003210 if opName.startswith("transpose_conv2d"):
3211 # transpose_conv2d
3212 dilations = args[2]
3213 output_shape = args[3]
3214 filter_shape = inputShapes[1]
3215 kernel_shape = filter_shape[1:-1]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003216
Les Bell0e027d42021-11-09 14:42:14 +00003217 def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
3218 """Calculate the transpose_conv2d output size for a dimension.
Matthew Haddonb724efc2021-08-25 16:40:29 +01003219
Les Bell0e027d42021-11-09 14:42:14 +00003220 Based on the keras function deconv_output_length, in
3221 https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
Matthew Haddonb724efc2021-08-25 16:40:29 +01003222
Les Bell0e027d42021-11-09 14:42:14 +00003223 Args:
3224 in_size: the input size - int
3225 stride: the stride - int
3226 kernel_size: the kernel size - int
3227 dilation: the kernel dilation - int
3228 out_pad: the output padding - int
3229 in_pad: the input padding - int
3230
3231 Returns:
3232 the output size
3233 """
3234 dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
3235 return (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
3236
3237 for pad_h, pad_w in (
3238 (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
3239 (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
3240 (0, 0) # VALID padding
3241 ):
3242 h = get_out_size(input_shape[1], strides[0], kernel_shape[0], dilations[0],
3243 padding[0], pad_h)
3244 w = get_out_size(input_shape[2], strides[1], kernel_shape[1], dilations[1],
3245 padding[1], pad_w)
3246 if output_shape[1] == h and output_shape[2] == w:
3247 return False
3248
3249 # output shape does not match the expected shape for any padding option
Matthew Haddonb724efc2021-08-25 16:40:29 +01003250 return True
Les Bell0e027d42021-11-09 14:42:14 +00003251
3252 if "conv2d" in opName or "conv3d" in opName:
3253 # conv2d, conv3d, depthwise_conv2d
3254 dilations = args[2]
3255 filter_shape = inputShapes[1]
3256 kernel_shape = filter_shape[0:2] if opName.startswith("depthwise_conv2d") else filter_shape[1:-1]
3257
3258 for i in range(len(kernel_shape)):
3259 dim = (
3260 input_shape[i + 1]
3261 - kernel_shape[i]
3262 - (kernel_shape[i] - 1) * (dilations[i] - 1)
3263 + padding[i * 2 + 0]
3264 + padding[i * 2 + 1]
3265 ) // strides[i] + 1
3266 # return True if any dimension is < 1
3267 if dim < 1:
3268 return True
3269 return False
3270
3271 assert False, f"Unrecognized Op: {opName}"
Matthew Haddonb724efc2021-08-25 16:40:29 +01003272
3273 @staticmethod
3274 def ivNonPositiveOutputShape(**kwargs):
3275 args = kwargs['args']
3276 output_shape = args[3]
3277 if output_shape[1] <= 0 or output_shape[2] <= 0:
3278 # Negative output shape
3279 return True
3280 return False
3281
3282
Eric Kunzee5e26762020-10-13 16:11:07 -07003283class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003284 # Maximum rank of tensor supported by test generator.
3285 TOSA_TENSOR_MAX_RANK = 6
3286
Eric Kunzee5e26762020-10-13 16:11:07 -07003287 def __init__(self, args):
3288 self.args = args
3289 self.basePath = args.output_dir
3290 self.random_seed = args.random_seed
3291 self.ser = None
3292 self.rng = np.random.default_rng(self.random_seed)
3293 self.createDynamicOpLists()
3294 self.initOpListDefaults()
3295 self.quantGen = TosaQuantGen()
3296 # Force makeShape to do a specific starting shape
3297 self.targetted_shape = None
3298
    def createSerializer(self, opName, testPath):
        """Create a fresh TosaSerializer writing under <basePath>/<opName>/<testPath>.

        Creates the output directory if needed and replaces self.ser.
        """
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)
3305
    def getSerializer(self):
        """Return the current TosaSerializer (None until createSerializer is called)."""
        return self.ser
3308
3309 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003310 with open(
3311 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
3312 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07003313 fd.write(self.ser.serialize())
3314
Kevin Cheng550ccc52021-03-03 11:21:43 -08003315 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
3316 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07003317
Matthew Haddon74567092021-07-16 15:38:20 +01003318 def resetRNG(self, seed=None):
3319 if seed == None:
3320 seed = self.random_seed + 1
3321 self.rng = np.random.default_rng(seed)
3322
Eric Kunzee5e26762020-10-13 16:11:07 -07003323 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07003324 if dtype == DType.BOOL:
3325 np_dt = np.bool
3326 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07003327 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07003328 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07003329 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003330 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003331 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
3332 elif dtype == DType.UINT8:
3333 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003334 elif dtype == DType.INT16:
3335 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
3336 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003337 return np.int32(
3338 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
3339 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003340 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003341 return np.int64(
3342 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
3343 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003344 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003345 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003346 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003347 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003348
Kevin Cheng989cb052021-04-28 16:29:44 -07003349 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003350 placeholders = []
3351
Kevin Cheng989cb052021-04-28 16:29:44 -07003352 assert len(shape_list) == len(dtype_list)
3353
3354 for idx, shape in enumerate(shape_list):
3355 arr = self.getRandTensor(shape, dtype_list[idx])
3356 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003357
3358 return placeholders
3359
Kevin Cheng989cb052021-04-28 16:29:44 -07003360 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003361 consts = []
3362
Kevin Cheng989cb052021-04-28 16:29:44 -07003363 assert len(shape_list) == len(dtype_list)
3364
3365 for idx, shape in enumerate(shape_list):
3366 arr = self.getRandTensor(shape, dtype_list[idx])
3367 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003368
3369 return consts
3370
3371 def makeShape(self, rank):
3372 if self.targetted_shape:
3373 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003374 return np.int32(
3375 self.rng.integers(
3376 low=self.args.tensor_shape_range[0],
3377 high=self.args.tensor_shape_range[1],
3378 size=rank,
3379 )
3380 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003381
    def setTargetShape(self, shape):
        # Force makeShape to return this exact shape (a falsy value re-enables
        # random shape generation).
        self.targetted_shape = shape
3384
3385 def randInt(self, low=0, high=256):
3386 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
3387
    def getRandNumberDType(self, dtype):
        """Return a single random scalar value appropriate for the TOSA dtype.

        FLOAT returns a Python float in [0, 1); BOOL returns a bool; integer
        types return np.int32 over the type's full range, except INT48 which
        returns np.int64 (values exceed the int32 range).

        Raises:
            Exception: if dtype is not a recognized DType.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size: returns int64 here instead of the int32 below.
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]
3410
3411 def shapeStr(self, shape):
3412
3413 sStr = []
3414 # Convert to strings
3415 for i in shape:
3416 sStr.append(str(i))
3417
Kevin Cheng550ccc52021-03-03 11:21:43 -08003418 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003419
3420 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07003421 if isinstance(t, list):
3422 assert len(t) >= 2
3423 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07003424 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07003425 if t == DType.BOOL:
3426 return "b"
3427 elif t == DType.INT4:
3428 return "i4"
3429 elif t == DType.INT8:
3430 return "i8"
3431 elif t == DType.UINT8:
3432 return "u8"
3433 elif t == DType.INT16:
3434 return "i16"
3435 elif t == DType.INT32:
3436 return "i32"
3437 elif t == DType.INT48:
3438 return "i48"
3439 elif t == DType.FLOAT:
3440 return "float"
3441 else:
3442 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003443
3444 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003445 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08003446 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07003447 return 4
3448 elif t == DType.INT8:
3449 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08003450 elif t == DType.UINT8:
3451 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07003452 elif t == DType.INT16:
3453 return 16
3454 elif t == DType.INT32:
3455 return 32
3456 elif t == DType.INT48:
3457 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01003458 elif t == DType.FLOAT:
3459 return 32
3460 elif t == DType.BOOL:
3461 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003462 else:
Les Bell729b0352021-11-24 10:28:21 +00003463 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -07003464
3465 # Argument generators
3466 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
3467 # Where the string descriptor is used to generate the test name and
3468 # The build_fcn_arg_list is expanded and passed to the operator test
3469 # build function
3470
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003471 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
3472 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
3473
Matthew Haddon848efb42021-09-09 12:30:53 +01003474 # build_placeholder returns an int, ABS/other ops does not
3475 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003476 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
3477 return result_tens
3478 elif op['op'] == Op.IDENTITY:
3479 self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
3480 return result_tens
3481
3482 # Ensure new output type has correct qinfo
3483 if error_name == ErrorIf.WrongOutputType:
3484 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
3485 qinfo = ts.TosaSerializerQuantInfo()
3486 qinfo.UnaryQuantInfo(
3487 TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
3488 )
3489
3490 # Invalidate Input/Output list for error if checks.
3491 input_list = [a.name]
3492 output_list = [result_tens.name]
3493 pCount, cCount = op["operands"]
3494 num_operands = pCount + cCount
3495 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3496
Les Bell729b0352021-11-24 10:28:21 +00003497 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003498 self.ser,
3499 validator_fcns,
3500 error_name,
3501 op=op,
3502 input_dtype=a.dtype,
3503 output_dtype=result_tens.dtype,
3504 qinfo = qinfo,
3505 result_tensor = result_tens,
3506 input_list=input_list,
3507 output_list=output_list,
3508 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003509 ):
3510 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003511
3512 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003513 return result_tens
3514
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003515 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
3516 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
3517
3518
3519 # Invalidate Input/Output list for error if checks.
3520 input_list = [a.name, b.name]
3521 output_list = [result_tens.name]
3522 pCount, cCount = op["operands"]
3523 num_operands = pCount + cCount
3524 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3525
Les Bell729b0352021-11-24 10:28:21 +00003526 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003527 self.ser,
3528 validator_fcns,
3529 error_name,
3530 op=op,
3531 input1 = a,
3532 input2 = b,
3533 input_dtype = a.dtype,
3534 output_dtype = result_tens.dtype,
3535 result_tensor = result_tens,
3536 input_list=input_list,
3537 output_list=output_list,
3538 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003539 ):
3540 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003541
3542 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07003543 return result_tens
3544
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003545 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07003546 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01003547 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003548 return result_tens
3549
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003550 def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None):
3551 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
3552
3553 # Invalidate Input/Output list for error if checks.
3554 input_list = [a.name, b.name]
3555 output_list = [result_tens.name]
3556 pCount, cCount = op["operands"]
3557 num_operands = pCount + cCount
3558 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3559
Les Bell729b0352021-11-24 10:28:21 +00003560 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003561 self.ser,
3562 validator_fcns,
3563 error_name,
3564 op=op,
3565 input1 = a,
3566 input2 = b,
3567 input_dtype = a.dtype,
3568 output_dtype = result_tens.dtype,
3569 result_tensor = result_tens,
3570 input_list=input_list,
3571 output_list=output_list,
3572 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003573 ):
3574 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -08003575
3576 attr = ts.TosaSerializerAttribute()
3577 attr.ArithmeticRightShiftAttribute(round)
3578
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003579 self.ser.addOperator(op['op'], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08003580 return result_tens
3581
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003582 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
3583 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003584
3585 # Special for multiply:
3586 # Force the result to INT32 for INT types
3587 if a.dtype != DType.FLOAT:
3588 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003589 if error_name == ErrorIf.WrongOutputType:
3590 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
3591 outputDType = self.rng.choice(all_dtypes)
3592 result_tens.setDtype(outputDType)
3593
3594 # Invalidate Input/Output list for error if checks.
3595 input_list = [a.name, b.name]
3596 output_list = [result_tens.name]
3597 pCount, cCount = op["operands"]
3598 num_operands = pCount + cCount
3599 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3600
Les Bell729b0352021-11-24 10:28:21 +00003601 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003602 self.ser,
3603 validator_fcns,
3604 error_name,
3605 op=op,
3606 input1 = a,
3607 input2 = b,
3608 input_dtype = a.dtype,
3609 output_dtype = result_tens.dtype,
3610 result_tensor = result_tens,
3611 input_list=input_list,
3612 output_list=output_list,
3613 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003614 ):
3615 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07003616
Kevin Chengaee1fac2020-11-11 13:54:06 -08003617 attr = ts.TosaSerializerAttribute()
3618 attr.MulAttribute(shift)
3619
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003620 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003621 return result_tens
3622
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003623 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
3624 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003625
Kevin Chengfe392ce2021-10-18 21:51:55 +00003626 attr = ts.TosaSerializerAttribute()
3627 attr.TableAttribute(table)
3628
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003629 # Invalidate Input/Output list for error if checks.
3630 input_list = [a.name]
3631 output_list = [result_tens.name]
3632 pCount, cCount = op["operands"]
3633 num_operands = pCount + cCount
3634 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3635
Les Bell729b0352021-11-24 10:28:21 +00003636 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003637 self.ser,
3638 validator_fcns,
3639 error_name,
3640 op=op,
3641 input_shape = a.shape,
3642 input_dtype = a.dtype,
3643 output_dtype = result_tens.dtype,
3644 result_tensor = result_tens,
3645 input_list=input_list,
3646 output_list=output_list,
3647 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003648 ):
3649 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003650
3651 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003652
3653 return result_tens
3654
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003655 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
3656 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
3657
3658 # Invalidate Input/Output list for error if checks.
3659 input_list = [cond.name, a.name, b.name]
3660 output_list = [result_tens.name]
3661 pCount, cCount = op["operands"]
3662 num_operands = pCount + cCount
3663 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3664
Les Bell729b0352021-11-24 10:28:21 +00003665 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003666 self.ser,
3667 validator_fcns,
3668 error_name,
3669 op=op,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003670 input1 = cond,
3671 input2 = a,
3672 input3 = b,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003673 input_shape = a.shape,
3674 input_dtype = a.dtype,
3675 output_dtype = result_tens.dtype,
3676 result_tensor = result_tens,
3677 input_list=input_list,
3678 output_list=output_list,
3679 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003680 ):
3681 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003682
3683 self.ser.addOperator(op['op'], input_list, output_list,)
Eric Kunzee5e26762020-10-13 16:11:07 -07003684 return result_tens
3685
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003686 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
3687 result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)
3688
3689 # Invalidate Input/Output list for error if checks.
3690 input_list = [a.name, b.name]
3691 output_list = [result_tens.name]
3692 pCount, cCount = op["operands"]
3693 num_operands = pCount + cCount
3694 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3695
Les Bell729b0352021-11-24 10:28:21 +00003696 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003697 self.ser,
3698 validator_fcns,
3699 error_name,
3700 op=op,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003701 input1 = a,
3702 input2 = b,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003703 input_shape = a.shape,
3704 input_dtype = a.dtype,
3705 output_shape = result_tens.shape,
3706 output_dtype = result_tens.dtype,
3707 result_tensor = result_tens,
3708 input_list=input_list,
3709 output_list=output_list,
3710 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003711 ):
3712 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003713
3714 self.ser.addOperator(op['op'], input_list, output_list,)
Eric Kunzee5e26762020-10-13 16:11:07 -07003715 return result_tens
3716
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003717 def build_argmax(self, op, a, axis, validator_fcns, error_name):
3718 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
3719
3720 # Invalidate Input/Output list for error if checks.
3721 input_list = [a.name]
3722 output_list = [result_tens.name]
3723 pCount, cCount = op["operands"]
3724 num_operands = pCount + cCount
3725 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3726
Les Bell729b0352021-11-24 10:28:21 +00003727 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003728 self.ser,
3729 validator_fcns,
3730 error_name,
3731 op=op,
3732 axis=axis,
3733 input_shape = a.shape,
3734 input_dtype = a.dtype,
3735 output_shape = result_tens.shape,
3736 output_dtype = result_tens.dtype,
3737 result_tensor = result_tens,
3738 input_list=input_list,
3739 output_list=output_list,
3740 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003741 ):
3742 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07003743
3744 attr = ts.TosaSerializerAttribute()
3745 attr.AxisAttribute(axis)
3746
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003747 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003748 return result_tens
3749
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003750 def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
3751 result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
3752
3753 # Ensure new output type has correct qinfo
3754 if error_name == ErrorIf.WrongInputType:
3755 if input.dtype not in [DType.INT8, DType.UINT8]:
3756 qinfo = ts.TosaSerializerQuantInfo()
3757 qinfo.UnaryQuantInfo(
Les Bell0e027d42021-11-09 14:42:14 +00003758 TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003759 )
3760
3761 # Invalidate Input/Output list for error if checks.
3762 input_list = [input.name]
3763 output_list = [result_tens.name]
3764 pCount, cCount = op["operands"]
3765 num_operands = pCount + cCount
3766 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3767
Les Bell729b0352021-11-24 10:28:21 +00003768 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003769 self.ser,
3770 validator_fcns,
3771 error_name,
3772 op=op,
3773 input_shape=input.shape,
3774 input_dtype=input.dtype,
3775 output_shape=result_tens.shape,
3776 output_dtype=result_tens.dtype,
3777 kernel=kernel,
3778 stride=stride,
3779 pad=pad,
3780 qinfo = qinfo,
3781 result_tensor = result_tens,
3782 input_list=input_list,
3783 output_list=output_list,
3784 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003785 ):
3786 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07003787
3788 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003789 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07003790
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003791 self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003792 return result_tens
3793
Les Bell0e027d42021-11-09 14:42:14 +00003794 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003795 assert len(padding) == 4
3796 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +00003797 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
3798 )
3799
3800 # Ensure new output type has correct qinfo
3801 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
3802 qinfo = ts.TosaSerializerQuantInfo()
3803 qinfo.ConvQuantInfo(
3804 TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
3805 )
3806
3807 # Invalidate Input/Output list for error_if checks.
3808 input_list = [ifm.name, filter.name, bias.name]
3809 output_list = [result_tens.name]
3810 num_operands = sum(op["operands"])
3811 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3812
Les Bell729b0352021-11-24 10:28:21 +00003813 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00003814 self.ser,
3815 validator_fcns,
3816 error_name,
3817 op=op,
3818 input_dtype=ifm.dtype,
3819 weight_dtype=filter.dtype,
3820 output_dtype=result_tens.dtype,
3821 qinfo=qinfo,
3822 input_list=input_list,
3823 num_operands=num_operands,
3824 output_list=output_list,
3825 pad=padding,
3826 stride=strides,
3827 dilation=dilations,
3828 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00003829 ):
3830 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07003831
3832 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07003833 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07003834
Kevin Cheng550ccc52021-03-03 11:21:43 -08003835 self.ser.addOperator(
Les Bell0e027d42021-11-09 14:42:14 +00003836 op['op'], input_list, output_list, attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08003837 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003838 return result_tens
3839
Les Bell0e027d42021-11-09 14:42:14 +00003840 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
Kevin Cheng1533b852021-09-01 12:51:58 -07003841 assert len(padding) == 6
3842 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +00003843 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
3844 )
3845
3846 # Ensure new output type has correct qinfo
3847 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
3848 qinfo = ts.TosaSerializerQuantInfo()
3849 qinfo.ConvQuantInfo(
3850 TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
3851 )
3852
3853 # Invalidate Input/Output list for error_if checks.
3854 input_list = [ifm.name, filter.name, bias.name]
3855 output_list = [result_tens.name]
3856 num_operands = sum(op["operands"])
3857 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
3858
Les Bell729b0352021-11-24 10:28:21 +00003859 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00003860 self.ser,
3861 validator_fcns,
3862 error_name,
3863 op=op,
3864 input_dtype=ifm.dtype,
3865 weight_dtype=filter.dtype,
3866 output_dtype=result_tens.dtype,
3867 qinfo=qinfo,
3868 input_list=input_list,
3869 num_operands=num_operands,
3870 output_list=output_list,
3871 pad=padding,
3872 stride=strides,
3873 dilation=dilations,
3874 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00003875 ):
3876 return None
Kevin Cheng1533b852021-09-01 12:51:58 -07003877
3878 attr = ts.TosaSerializerAttribute()
3879 attr.ConvAttribute(padding, strides, dilations)
3880
3881 self.ser.addOperator(
Les Bell0e027d42021-11-09 14:42:14 +00003882 op['op'], input_list, output_list, attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07003883 )
3884 return result_tens
3885
    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Build and serialize a TRANSPOSE_CONV2D operator for a generated test.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        # Only top/left output padding is carried here (two values).
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, self.rng, ifm, output_shape, error_name)

        # Ensure new output type has correct qinfo.  WrongInputType tests swap
        # the input dtype, so regenerate qinfo from the actual dtypes.
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
            )

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=outpad,
            stride=stride,
            dilation=dilation,
            input_shape=ifm.shape,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
3931
    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Build and serialize a DEPTHWISE_CONV2D operator for a generated test.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )

        # Ensure new output type has correct qinfo.  WrongInputType tests swap
        # the input dtype, so regenerate qinfo from the actual dtypes.
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
            )

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        # Depthwise conv shares the generic ConvAttribute layout.
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
3978
    def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
        """Build and serialize a FULLY_CONNECTED operator for a generated test.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.

        NOTE(review): unlike the conv builders, this does not regenerate
        qinfo for ErrorIf.WrongInputType — confirm whether that is intended.
        """
        result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=ifm.shape,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        # FULLY_CONNECTED carries no attribute, only qinfo.
        self.ser.addOperator(
            op['op'], input_list, output_list, None, qinfo
        )
        return result_tens
4011
    def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
        """Build and serialize a MATMUL operator for a generated test.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            input_dtype=a.dtype,
            input2_shape=b.shape,
            input2_dtype=b.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        # MATMUL carries no attribute, only qinfo.
        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens
4043
    def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
        """Build and serialize a reduction operator (REDUCE_*) along `axis`.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis = axis,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4076
    def build_clamp(self, op, a, validator_fcns=None, error_name=None):
        """Build and serialize a CLAMP operator with random min/max bounds.

        For the MaxSmallerMin ERROR_IF case the bounds are deliberately
        swapped (max < min).  Returns the output tensor, or None when the
        ERROR_IF validators reject the configuration.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # Two random candidate bounds in the input's dtype range.
        v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        if error_name == ErrorIf.MaxSmallerMin:
            # Make sure the numbers are different to invoke this error
            while v[0] == v[1]:
                v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
            # Intentionally inverted: max gets the smaller value.
            max_val = min(v)
            min_val = max(v)
        else:
            max_val = max(v)
            min_val = min(v)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            max_val=max_val,
            min_val=min_val,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        # ClampAttribute(int_min, int_max, fp_min, fp_max): float inputs use
        # the fp fields, integer inputs the int fields; the others stay zero.
        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min_val, max_val)
        else:
            attr.ClampAttribute(min_val, max_val, 0, 0)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4125
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004126 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
4127 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004128 attr = ts.TosaSerializerAttribute()
4129
4130 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
4131
Matthew Haddon848efb42021-09-09 12:30:53 +01004132 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004133 return result_tens
4134
4135 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004136 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
4137 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004138
Matthew Haddon848efb42021-09-09 12:30:53 +01004139 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07004140 return result_tens
4141
    def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
        """Build and serialize a SIGMOID operator.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op['op'], input_list, output_list)
        return result_tens
4170
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004171 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
4172 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4173
4174 # Invalidate Input/Output list for error if checks.
4175 input_list = [a.name]
4176 output_list = [result_tens.name]
4177 pCount, cCount = op["operands"]
4178 num_operands = pCount + cCount
4179 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4180
Les Bell729b0352021-11-24 10:28:21 +00004181 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004182 self.ser,
4183 validator_fcns,
4184 error_name,
4185 op=op,
4186 input_shape = a.shape,
4187 output_shape = result_tens.shape,
4188 input_dtype = a.dtype,
4189 output_dtype = result_tens.dtype,
4190 result_tensor = result_tens,
4191 input_list=input_list,
4192 output_list=output_list,
4193 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004194 ):
4195 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004196
4197 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004198 return result_tens
4199
    def build_concat(self, op, *a, validator_fcns=None, error_name=None):
        """Build and serialize a CONCAT operator.

        The variadic `a` is the input tensor list with the concat axis
        appended as its final (int) element.  Returns the output tensor, or
        None when the ERROR_IF validators reject the configuration.
        """
        if error_name != ErrorIf.WrongInputType:
            assert type(a[-1]) == int

        # To store variable length list of input tensors we need to store axis along with it
        axis = a[-1]
        a = a[:-1]

        result_tens = OutputShaper.concatOp(self.ser, self.rng, axis, *a, error_name=error_name)

        input_tensor_names = []
        for tensor in a:
            input_tensor_names.append(tensor.name)

        # Invalidate Input/Output list for error if checks.
        input_list = input_tensor_names
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape = a[0].shape,
            output_shape = result_tens.shape,
            input_dtype = a[0].dtype,
            output_dtype = result_tens.dtype,
            inputs=a,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)


        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004245
    def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
        """Build and serialize a PAD operator.

        `padding` is an array (flattened into the attribute); the int and
        float pad constants are both carried, with the serializer picking
        the relevant one.  Returns the output tensor, or None when the
        ERROR_IF validators reject the configuration.
        """
        result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            pad=padding,
            qinfo=qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op['op'], input_list, output_list, attr, qinfo
        )
        return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004281
    def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
        """Build and serialize a RESHAPE operator to `newShape`.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4313
    def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
        """Build and serialize a REVERSE operator along `axis`.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4346
    def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
        """Build and serialize a TRANSPOSE operator with permutation `perms`.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeAttribute(perms)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            perms=perms,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None


        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4380
    def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
        """Build and serialize a SLICE operator (per-dimension start/size).

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            start=start,
            size=size,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(start, size)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4414
    def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
        """Build and serialize a TILE operator with per-dimension `multiples`.

        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """
        result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = a.shape,
            output_shape = result_tens.shape,
            input_dtype = a.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
4446
    def build_gather(self, op, values, validator_fcns=None, error_name=None):
        """Build and serialize a GATHER operator.

        Generates a fresh INT32 indices constant tensor of shape (N, W)
        with values in [0, K), where K = values.shape[1] and W is random.
        Returns the output tensor, or None when the ERROR_IF validators
        reject the configuration.
        """

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values tensor

        K = values.shape[1]  # K
        # W is drawn from the configured tensor shape range.
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indicies, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [values.name, indicies.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = values.shape,
            output_shape = result_tens.shape,
            input_dtype = values.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op['op'], input_list, output_list)

        return result_tens
4489
    def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
        """Build and serialize a SCATTER operator.

        Generates a fresh INT32 indices constant tensor of shape (N, W)
        with values in [0, K), where K = values_in.shape[1] and W is taken
        from the update tensor's shape.  Returns the output tensor, or None
        when the ERROR_IF validators reject the configuration.
        """

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, self.rng, values_in, indicies, input, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [values_in.name, indicies.name, input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Abort (drop this op) if the ERROR_IF combination does not validate.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape = values_in.shape,
            output_shape = result_tens.shape,
            input_dtype = values_in.dtype,
            output_dtype = result_tens.dtype,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op['op'], input_list, output_list)

        return result_tens
4530
Matthew Haddon848efb42021-09-09 12:30:53 +01004531
Kevin Cheng550ccc52021-03-03 11:21:43 -08004532 def build_resize(
4533 self,
4534 op,
4535 input,
4536 mode,
4537 stride,
4538 offset,
4539 shift,
4540 stride_fp,
4541 offset_fp,
4542 output_dims,
4543 input_dtype,
4544 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004545 validator_fcns,
4546 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004547 ):
4548 result_tens = OutputShaper.resizeOp(
4549 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004550 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004551 input,
4552 mode,
4553 stride,
4554 offset,
4555 shift,
4556 stride_fp,
4557 offset_fp,
4558 output_dims,
4559 input_dtype,
4560 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004561 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08004562 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004563
Matthew Haddon848efb42021-09-09 12:30:53 +01004564 # Invalidate Input/Output list for error if checks.
4565 input_list = [input.name]
4566 output_list = [result_tens.name]
4567 pCount, cCount = op["operands"]
4568 num_operands = pCount + cCount
4569 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01004570
Les Bell729b0352021-11-24 10:28:21 +00004571 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon848efb42021-09-09 12:30:53 +01004572 self.ser,
4573 validator_fcns,
4574 error_name,
4575 op=op,
4576 mode=mode,
4577 shift=shift,
4578 input_dtype=input_dtype,
4579 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004580 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01004581 output_shape=output_dims,
4582 offset=offset,
4583 offset_fp=offset_fp,
4584 stride=stride,
4585 stride_fp=stride_fp,
4586 input_list=input_list,
4587 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004588 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01004589 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004590 ):
4591 return None
Matthew Haddone86fd342021-09-07 16:12:21 +01004592
Eric Kunzee5e26762020-10-13 16:11:07 -07004593 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08004594
Kevin Cheng550ccc52021-03-03 11:21:43 -08004595 attr.ResizeAttribute(
4596 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
4597 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004598
Matthew Haddon848efb42021-09-09 12:30:53 +01004599 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004600 return result_tens
4601
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004602 def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
4603 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
4604 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004605 self.ser.addOperator(
4606 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
4607 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004608 return result_tens
4609
    def build_const(self, op, val, validator_fcns=None, error_name=None):
        """Build a CONST operator test.

        `val` is already a const tensor created earlier by tensor generation;
        this only registers it as a graph output and returns it. The
        validator_fcns/error_name parameters exist for signature parity with
        the other build_* methods and are unused here.
        """
        self.ser.addOutputTensor(val)
        return val
Eric Kunzee5e26762020-10-13 16:11:07 -07004613
4614 # Type Conversion
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004615 def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
4616 result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
4617
4618 # Invalidate Input/Output list for error if checks.
4619 input_list = [val.name]
4620 output_list = [result_tens.name]
4621 pCount, cCount = op["operands"]
4622 num_operands = pCount + cCount
4623 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4624
Les Bell729b0352021-11-24 10:28:21 +00004625 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004626 self.ser,
4627 validator_fcns,
4628 error_name,
4629 op=op,
4630 input_shape = val.shape,
4631 output_shape = result_tens.shape,
4632 input_dtype = val.dtype,
4633 output_dtype = result_tens.dtype,
4634 result_tensor = result_tens,
4635 input_list=input_list,
4636 output_list=output_list,
4637 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004638 ):
4639 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004640
4641 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004642 return result_tens
4643
Matthew Haddonc2025212021-10-08 21:21:05 +01004644 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004645 result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004646
4647 if per_channel:
4648 nc = val.shape[-1]
4649 else:
4650 nc = 1
4651
4652 in_type_width = self.typeWidth(val.dtype)
4653 out_type_width = self.typeWidth(out_dtype)
4654
Kevin Cheng3a478572021-01-22 17:21:02 -08004655 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004656 input_zp = self.randInt(-128, 128)
4657 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07004658 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004659 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07004660 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01004661 elif error_name == ErrorIf.InputZeroPointNotZero:
4662 input_zp = self.randInt(-128, 128)
4663 if input_zp == 0:
4664 input_zp = input_zp + self.rng.integers(1, 10)
4665 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004666 else:
4667 input_zp = 0
4668
Kevin Cheng3a478572021-01-22 17:21:02 -08004669 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004670 output_zp = self.randInt(-128, 128)
4671 out_type_width = out_type_width + 1
4672 elif out_dtype == DType.UINT8:
4673 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07004674 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01004675 elif error_name == ErrorIf.OutputZeroPointNotZero:
4676 output_zp = self.randInt(-128, 128)
4677 if output_zp == 0:
4678 output_zp = output_zp + self.rng.integers(1, 10)
4679 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004680 else:
4681 output_zp = 0
4682
4683 # Calculate scale based on:
4684 # scale = a *(2^output_width)/(2^input_width))
4685
4686 a = np.float32(self.rng.random(size=[nc]))
4687 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
4688
4689 if scale32:
4690 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01004691 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07004692 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
4693 else:
4694 # Cap the scaling at 2^15 - 1 for scale16
4695 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
4696
Kevin Cheng550ccc52021-03-03 11:21:43 -08004697 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07004698
4699 multiplier_arr = np.int32(np.zeros(shape=[nc]))
4700 shift_arr = np.int32(np.zeros(shape=[nc]))
4701
4702 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004703 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
4704 scale_arr[i], scale32
4705 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004706
Kevin Cheng550ccc52021-03-03 11:21:43 -08004707 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07004708
Matthew Haddonc2025212021-10-08 21:21:05 +01004709 # Invalidate Input/Output list for error if checks.
4710 input_list = [val.name]
4711 output_list = [result_tens.name]
4712 pCount, cCount = op["operands"]
4713 num_operands = pCount + cCount
4714 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4715
4716 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00004717 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01004718 self.ser,
4719 validator_fcns,
4720 error_name,
4721 op=op,
4722 input_dtype=val.dtype,
4723 output_dtype=out_dtype,
4724 input_shape=val.shape,
4725 qinfo=qinfo,
4726 scale32 = scale32,
4727 double_round = double_round,
4728 input_list=input_list,
4729 output_list=output_list,
4730 result_tensor=result_tens,
4731 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004732 ):
4733 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01004734
Eric Kunzee5e26762020-10-13 16:11:07 -07004735 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004736 attr.RescaleAttribute(
4737 input_zp,
4738 output_zp,
4739 multiplier_arr,
4740 shift_arr,
4741 scale32,
4742 double_round,
4743 per_channel,
4744 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004745
Matthew Haddonc2025212021-10-08 21:21:05 +01004746 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004747 return result_tens
4748
    def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
        """Build a COND_IF test whose then/else blocks contain only const nodes.

        The supplied then/else tensors are used only for their shape; the
        block bodies are filled with freshly generated INT32 const tensors.
        For the CondIfOutputList*GraphMismatch negative tests, one block's
        const deliberately gets a mismatched shape.

        Returns the result tensor, or None if ERROR_IF validation rejected
        the generated negative test.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape

        # Create an incorrect output shape for error_if tests
        # (each dim nudged by +/-2 or +/-3 so it can never match the original;
        # NOTE(review): a small dim could go non-positive here — presumably
        # acceptable for a deliberately-invalid negative test, verify).
        if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
            incorrect_shape = deepcopy(then_tens.shape)
            for i in range(len(incorrect_shape)):
                incorrect_shape[i] = incorrect_shape[i] + self.rng.choice([-3, -2, 2, 3])
            incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        # startBasicBlock switches the serializer's current block; the
        # addConst/addOutputTensor calls below land in that block.
        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
            then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
            else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        # Validation runs last so it can inspect the fully built basic blocks.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            basicBlocks=self.ser.basicBlocks
        ):
            return None

        return result_tens
4807
Matthew Haddon630c17c2021-10-14 15:05:41 +01004808 def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004809 # For cond_if with a binary op in the then/else blocks, take a and b and
4810 # alternately add or subtract them based on the condition
4811
4812 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08004813 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07004814
Kevin Cheng550ccc52021-03-03 11:21:43 -08004815 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004816
4817 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08004818 then_block = "THEN_BLOCK"
4819 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07004820 attr = ts.TosaSerializerAttribute()
4821 attr.CondIfAttribute(then_block, else_block)
4822
Matthew Haddon630c17c2021-10-14 15:05:41 +01004823 if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
4824 ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
4825 incorrect_shape = a.shape.copy()
4826 for i in range(len(incorrect_shape)):
4827 incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
4828 incorrect_block_input = deepcopy(a)
4829 incorrect_block_input.shape = incorrect_shape
4830
4831
Eric Kunzee5e26762020-10-13 16:11:07 -07004832 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08004833 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01004834 op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08004835 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004836
Les Bell6040b4d2021-10-11 12:50:31 +01004837 if a.dtype in (DType.FLOAT, DType.INT32):
4838 then_op, else_op = Op.ADD, Op.SUB
4839 elif a.dtype in (DType.INT8, DType.INT16):
4840 then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
4841 else:
4842 assert False, f"No tests for DType: {a.dtype}"
Eric Kunzee5e26762020-10-13 16:11:07 -07004843
Les Bell6040b4d2021-10-11 12:50:31 +01004844 for block, op in ((then_block, then_op), (else_block, else_op)):
4845 self.ser.startBasicBlock(block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01004846 if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
4847 (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
4848 self.ser.addInputTensor(incorrect_block_input)
4849 self.ser.addInputTensor(b)
4850 tens = self.ser.addOutput(a.shape, a.dtype)
4851 elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
4852 (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
4853 self.ser.addInputTensor(a)
4854 self.ser.addInputTensor(b)
4855 tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
4856 else:
4857 self.ser.addInputTensor(a)
4858 self.ser.addInputTensor(b)
4859 tens = self.ser.addOutput(a.shape, a.dtype)
Les Bell6040b4d2021-10-11 12:50:31 +01004860 self.ser.addOperator(op, [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07004861
Les Bell729b0352021-11-24 10:28:21 +00004862 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01004863 self.ser,
4864 validator_fcns,
4865 error_name,
4866 op=op,
4867 a=a,
4868 b=b,
4869 basicBlocks=self.ser.basicBlocks
Les Bell729b0352021-11-24 10:28:21 +00004870 ):
4871 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01004872
Eric Kunzee5e26762020-10-13 16:11:07 -07004873 return result_tens
4874
    def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
        """Build a WHILE_LOOP test that sums `a` into an accumulator
        `iter_val` times.

        The cond block tests `iter > 0`; the body block adds `a` to the
        accumulator and decrements `iter`. Several ERROR_IF names inject
        deliberately mismatched tensor lists/shapes into the cond/body
        blocks for negative testing.

        Returns the accumulator output tensor, or None if ERROR_IF
        validation rejected the generated negative test.
        """
        # Loop counter placeholder (rank-0 INT32), initialised to iter_val.
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        # For InputListOutputListMismatch, give the accumulator output a
        # shape that cannot match the input list (dims nudged by +/-2 or 3).
        if error_name == ErrorIf.InputListOutputListMismatch:
            incorrect_acc = deepcopy(acc)
            for i in range(len(incorrect_acc.shape)):
                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
            acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
        else:
            acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # Build mismatched iter/acc tensors for the cond/body-graph
        # negative tests. A rank-0 iter gets a dummy dim appended so the
        # shape genuinely differs.
        if error_name in [ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch]:
            incorrect_iter = deepcopy(iter)
            for i in range(len(incorrect_iter.shape)):
                incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
            if len(incorrect_iter.shape) == 0:
                incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

            incorrect_acc = deepcopy(acc)
            for i in range(len(incorrect_acc.shape)):
                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

        # COND block (input: iter, output: cond_tens )
        # startBasicBlock switches the serializer's current block.
        self.ser.startBasicBlock(cond_block)
        if error_name == ErrorIf.InputListCondGraphMismatch:
            self.ser.addInputTensor(incorrect_iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(incorrect_acc)
        else:
            self.ser.addInputTensor(iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

        # CondGraphOutputNotMatchingBool: cond output must be BOOL, so pick
        # any non-BOOL dtype to trigger the error.
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
            cond_tens = self.ser.addOutput([], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]))
        else:
            cond_tens = self.ser.addOutput([], DType.BOOL)

        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        if error_name == ErrorIf.InputListBodyGraphInputMismatch:
            self.ser.addInputTensor(incorrect_iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(incorrect_acc)
        else:
            self.ser.addInputTensor(iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(acc)

        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

        if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
            iter_body_out = self.ser.addIntermediate(incorrect_iter.shape, incorrect_iter.dtype)
            acc_body_out = self.ser.addIntermediate(incorrect_acc.shape, incorrect_acc.dtype)
        else:
            iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
            acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # Body: acc += a; iter -= 1; pass a through unchanged.
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        # Validation runs last so it can inspect the fully built basic blocks.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            basicBlocks=self.ser.basicBlocks
        ):
            return None

        return acc_out
4976
Matthew Haddon1c00b712021-10-01 15:51:03 +01004977 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
4978 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
4979 default_test_rank_range = range(1, 5)
4980 if not shapeFilter:
4981 shapeFilter = [None]
4982
4983 # Calculate the filters based on what is requested and what the operator allows
4984 rmin, rmax = op["rank"]
4985 if rankFilter is not None:
4986 cleanRankFilter = []
4987 # Ensure rankFilter values are allowed by operator
4988 for rank in rankFilter:
4989 if rank >= rmin and rank <= rmax:
4990 cleanRankFilter.append(rank)
4991 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01004992 # Ensure default behaviour is bounded by default range or by operator,
4993 # whichever is the smaller range of ranks.
4994 opRankRange = range(rmin, rmax + 1)
4995 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01004996 else:
4997 cleanRankFilter = range(rmin, rmax + 1)
4998
4999 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005000
Matthew Haddon1c00b712021-10-01 15:51:03 +01005001 if dtypeFilter is not None:
5002 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01005003 # Create list of operator dtypes filtered by requested dtypes
5004 for dtype in dtypes:
5005 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005006 cleanDtypeFilter.append(dtype)
5007 else:
5008 cleanDtypeFilter = dtypes
5009
5010 if testType == 'positive':
5011 filterDict = {
5012 'shapeFilter': shapeFilter,
5013 'rankFilter': cleanRankFilter,
5014 'dtypeFilter': cleanDtypeFilter
5015 }
5016 return filterDict
5017 elif testType == 'negative':
Matthew Haddone807aae2021-10-11 18:12:58 +01005018 if validator is not None:
5019 validator_info = validator(check=False, op=op)
5020 else:
5021 return None
5022
Matthew Haddon1c00b712021-10-01 15:51:03 +01005023 error_arguments = validator_info['param_reqs']
5024
5025 #Set parameters as required
5026 if error_arguments['rank'] != None:
5027 rankFilter = error_arguments['rank']
5028 else:
5029 rankFilter = cleanRankFilter
5030
5031 if error_arguments['dtype'] != None:
5032 dtypeFilter = error_arguments['dtype']
5033 else:
5034 dtypeFilter = cleanDtypeFilter
5035
5036 if error_arguments['shape'] != None:
5037 shapeFilter = error_arguments['shape']
5038 else:
5039 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
5040
5041 filterDict = {
5042 'shapeFilter': shapeFilter,
5043 'rankFilter': rankFilter,
5044 'dtypeFilter': dtypeFilter
5045 }
5046 return filterDict
5047
5048
Kevin Cheng550ccc52021-03-03 11:21:43 -08005049 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01005050 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08005051 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005052
5053 try:
5054 op = self.TOSA_OP_LIST[opName]
5055 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005056 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005057
5058 # Initialize a new random number generator
5059 self.rng = np.random.default_rng(self.random_seed)
5060
Kevin Cheng550ccc52021-03-03 11:21:43 -08005061 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005062
Eric Kunzee5e26762020-10-13 16:11:07 -07005063 # Test list consists of a tuple of:
5064 # (opName, testNameStr, dtype, shapeList, argumentsList)
5065 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01005066 if testType == 'negative' and "error_if_validators" in op:
5067 error_if_validators = op["error_if_validators"]
5068 else:
5069 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07005070
Matthew Haddon1c00b712021-10-01 15:51:03 +01005071 for validator in error_if_validators:
5072 if validator is not None:
5073 error_name = validator(check=False, op=op)['error_name']
Matthew Haddon1c00b712021-10-01 15:51:03 +01005074 else:
5075 error_name = None
5076
5077 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
Matthew Haddone807aae2021-10-11 18:12:58 +01005078 if filterDict == None:
5079 return []
Matthew Haddon1c00b712021-10-01 15:51:03 +01005080 cleanRankFilter = filterDict['rankFilter']
5081 cleanDtypeFilter = filterDict['dtypeFilter']
5082 cleanShapeFilter = filterDict['shapeFilter']
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005083 #print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005084
5085 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005086 for t in cleanDtypeFilter:
5087 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01005088 # Filter out by rank
5089 if shape is not None and len(shape) != r:
5090 continue
Matthew Haddon74567092021-07-16 15:38:20 +01005091 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005092 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005093
Matthew Haddon74567092021-07-16 15:38:20 +01005094 shapeStr = self.shapeStr(shapeList[0])
5095 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07005096
Matthew Haddon74567092021-07-16 15:38:20 +01005097 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
5098 argList = []
5099 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005100 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005101 else:
Matthew Haddon74567092021-07-16 15:38:20 +01005102 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07005103
Matthew Haddon74567092021-07-16 15:38:20 +01005104 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005105 if testType == 'positive':
5106 if argStr:
5107 testStr = "{}_{}_{}_{}".format(
5108 opName, shapeStr, typeStr, argStr
5109 )
5110 else:
5111 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
5112 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01005113 if argStr:
5114 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
5115 opName, error_name, shapeStr, typeStr, argStr
5116 )
5117 else:
5118 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005119
5120 testList.append((opName, testStr, t, error_name, shapeList, args))
5121
5122 if testType == 'positive':
5123 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
5124 if "invalid_test_validators" in op:
5125 invalid_test_validators = op["invalid_test_validators"]
5126 clean_testList = []
5127 for test in testList:
5128 for validator_fcn in invalid_test_validators:
5129 remove_test = False
5130 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
5131 remove_test = True
5132 if not remove_test:
5133 clean_testList.append(test)
5134 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07005135
5136 return testList
5137
Matthew Haddone86fd342021-09-07 16:12:21 +01005138
5139 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07005140 try:
5141 op = self.TOSA_OP_LIST[opName]
5142 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005143 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005144
5145 # Create a serializer
5146 self.createSerializer(opName, testStr)
5147
Kevin Cheng550ccc52021-03-03 11:21:43 -08005148 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01005149 if "error_if_validators" in op:
5150 error_if_validators = op["error_if_validators"]
5151 else:
5152 error_if_validators = None
5153
Kevin Cheng550ccc52021-03-03 11:21:43 -08005154 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07005155 num_operands = pCount + cCount
5156
5157 if isinstance(dtype_or_dtypeList, list):
5158 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07005159 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005160 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07005161 else:
5162 dtypeList = [dtype_or_dtypeList] * (num_operands)
5163
Kevin Cheng93a16282021-08-31 16:14:03 -07005164 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005165 assert (
5166 len(shapeList) == num_operands
5167 ), "shapeList length {} must match number of operands {}".format(
5168 len(shapeList), num_operands
5169 )
5170 assert (
5171 len(dtypeList) == num_operands
5172 ), "dtypeList length {} must match number of operands {}".format(
5173 len(dtypeList), num_operands
5174 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005175
5176 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005177 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005178 except KeyError:
5179 qgen = None
5180
5181 # Build the random tensor operands and the test
5182 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08005183
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005184 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005185
5186 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005187 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005188 else:
5189 qinfo = None
5190
5191 try:
5192 if error_if_validators is None:
5193 if qinfo is not None:
5194 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
5195 else:
5196 resultName = build_fcn(self, op, *tens, *testArgs)
5197 else:
5198 if qinfo is not None:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005199 resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, qinfo=qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005200 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005201 resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005202 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00005203 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005204 raise e
5205
Les Bell729b0352021-11-24 10:28:21 +00005206 if resultName:
5207 # The test is valid, serialize it
5208 self.serialize("test")
5209 else:
5210 # The test is not valid
5211 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005212
5213
    def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
        """Build the input tensors (placeholders/consts) for one test of *op*.

        Most ops get generic random placeholder/const tensors, but several ops
        need specially-constrained data so the generated test stays valid
        (no saturation, no divide-by-zero, shift amounts in range, etc.) —
        see the per-op branches below.

        Args:
            op: TOSA_OP_LIST entry for the operator under test; "operands"
                gives the (placeholder, const) counts and "op" the opcode.
            dtypeList: datatype per input tensor.  May be mutated in place
                (SELECT forces the condition dtype to BOOL).
            shapeList: shape per input tensor.
            testArgs: extra op arguments (e.g. shift for MUL, axis for CONCAT).
                May be mutated in place (CONCAT coerces the axis to int).
            error_name: ERROR_IF case name, or None when building a valid test.

        Returns:
            List of serializer tensor handles, placeholders before consts.
        """
        pCount, cCount = op["operands"]

        tens = []
        if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = (op["op"] == Op.ADD)
            a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the exact result in int64 so overflow can be detected
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31)-1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                # (the clipping above may have broadcast b up to the result shape)
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            tens.extend(placeholders)
        elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                # Data is generated at INT16 range even though the tensor dtype
                # stays dtypeList[idx], leaving headroom for internal add/sub
                arr = self.getRandTensor(shapeList[idx], DType.INT16)
                if pRemain > 0:
                    placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
                    pRemain -= 1
                else:
                    placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            # Force value of operand[1] to be within [0, num_bits]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Shift-amount operand: bound by the input bit width
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    elif error_name == ErrorIf.WrongInputType:
                        # Deliberately-wrong dtype tests still need some data
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.INTDIV and error_name == None:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Re-roll the random data until neither case is present.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL and error_name == None:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop regenerates a_arr/b_arr once per
                # entry in shapeList, but only the final iteration's arrays
                # appear to be used below - looks redundant, confirm intent.
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                # Halve both operands until the (rounded, shifted) product
                # fits in int32, so the generated MUL test cannot saturate.
                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2 ** 31)).all() and (
                        result_arr <= ((2 ** 31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            # Split the inputs into placeholders and consts per the
            # --num-const-inputs-concat command-line option
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            # Ensure axis is an int
            testArgs[0] = int(testArgs[0])

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name)

            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            # Default: generic random placeholders then consts
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07005436
5437 def createDynamicOpLists(self):
5438
5439 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07005440 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005441
Kevin Cheng1533b852021-09-01 12:51:58 -07005442 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005443 testName = "conv2d_{}x{}".format(k[0], k[1])
5444 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
5445 self.TOSA_OP_LIST[testName]["filter"] = k
5446 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005447
Kevin Cheng550ccc52021-03-03 11:21:43 -08005448 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
5449 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
5450 "depthwise_conv2d_TEMPLATE"
5451 ].copy()
5452 self.TOSA_OP_LIST[testName]["filter"] = k
5453 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005454
Kevin Cheng550ccc52021-03-03 11:21:43 -08005455 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
5456 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
5457 "transpose_conv2d_TEMPLATE"
5458 ].copy()
5459 self.TOSA_OP_LIST[testName]["filter"] = k
5460 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005461
Kevin Cheng1533b852021-09-01 12:51:58 -07005462 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
5463 for k in KERNELS_3D:
5464 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
5465 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
5466 self.TOSA_OP_LIST[testName]["filter"] = k
5467 self.TOSA_OP_LIST[testName]["template"] = False
5468
Eric Kunzee5e26762020-10-13 16:11:07 -07005469 # Delete any templates after having created any dynamic ops
5470 # This is a two-pass operation because it's bad practice to delete
5471 # keys from dictionaries while iterating
5472 keyList = []
5473 for k in self.TOSA_OP_LIST:
5474 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005475 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07005476 keyList.append(k)
5477 continue
5478 except KeyError:
5479 pass
5480
5481 for k in keyList:
5482 del self.TOSA_OP_LIST[k]
5483
5484 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005485 """Fill in default fields for ops if they aren't already specified.
5486 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07005487 for op in self.TOSA_OP_LIST:
5488
5489 # Required fields
5490 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005491 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005492 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005493 raise Exception(
5494 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
5495 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005496
5497 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005498 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005499 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005500 raise Exception(
5501 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
5502 op
5503 )
5504 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005505
5506 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005507 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005508 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005509 raise Exception(
5510 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
5511 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005512
5513 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005514 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005515 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005516 raise Exception(
5517 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
5518 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005519
5520 # Put in default rank range, if missing
5521 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005522 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005523 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005524 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07005525
5526 # Tensor operator list
5527 # 'op': op name
5528 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08005529 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
5530 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07005531 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
5532 # 'types': array of datatypes to be tested
    # Datatype lists referenced by the "types" fields of TOSA_OP_LIST below.
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Conv-style ops list dtype triples or a single all-float dtype.
    # NOTE(review): triples appear to be (input, weight, accumulator)
    # datatypes - confirm against the conv build functions.
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Default (min, max) rank applied by initOpListDefaults() to any
    # TOSA_OP_LIST entry without an explicit "rank" field.
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07005553
5554 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08005555 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08005556 "argmax": {
5557 "op": Op.ARGMAX,
5558 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005559 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005560 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5561 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005562 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
5563 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5564 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005565 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005566 "avg_pool2d": {
5567 "op": Op.AVG_POOL2D,
5568 "operands": (1, 0),
5569 "rank": (4, 4),
5570 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5571 "qgen": TosaQuantGen.qgUnary,
5572 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00005573 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005574 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5575 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5576 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5577 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005578 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005579 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005580 "conv2d_TEMPLATE": {
5581 "op": Op.CONV2D,
5582 "operands": (1, 2),
5583 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01005584 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005585 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005586 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005587 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5588 "error_if_validators": (
5589 TosaErrorValidator.evWrongInputType,
5590 TosaErrorValidator.evWrongOutputType,
5591 TosaErrorValidator.evWrongInputList,
5592 TosaErrorValidator.evWrongOutputList,
5593 TosaErrorValidator.evInputZeroPointNotZero,
5594 TosaErrorValidator.evWeightZeroPointNotZero,
5595 TosaErrorValidator.evPadSmallerZero,
5596 TosaErrorValidator.evStrideSmallerOne,
5597 TosaErrorValidator.evDilationSmallerOne,
5598 TosaErrorValidator.evWrongRank,
5599 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005600 "template": True,
5601 },
Kevin Cheng1533b852021-09-01 12:51:58 -07005602 # Templated operator. Filled in by createDynamicOpLists
5603 "conv3d_TEMPLATE": {
5604 "op": Op.CONV3D,
5605 "operands": (1, 2),
5606 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01005607 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07005608 "qgen": TosaQuantGen.qgConv,
5609 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005610 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5611 "error_if_validators": (
5612 TosaErrorValidator.evWrongInputType,
5613 TosaErrorValidator.evWrongOutputType,
5614 TosaErrorValidator.evWrongInputList,
5615 TosaErrorValidator.evWrongOutputList,
5616 TosaErrorValidator.evInputZeroPointNotZero,
5617 TosaErrorValidator.evWeightZeroPointNotZero,
5618 TosaErrorValidator.evPadSmallerZero,
5619 TosaErrorValidator.evStrideSmallerOne,
5620 TosaErrorValidator.evDilationSmallerOne,
5621 TosaErrorValidator.evWrongRank,
5622 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07005623 "template": True,
5624 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005625 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005626 "depthwise_conv2d_TEMPLATE": {
5627 "op": Op.DEPTHWISE_CONV2D,
5628 "operands": (1, 2),
5629 "filter": [1, 1],
5630 "rank": (4, 4),
5631 "build_fcn": (
5632 build_depthwise_conv2d,
5633 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01005634 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005635 ),
5636 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005637 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005638 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5639 "error_if_validators": (
5640 TosaErrorValidator.evWrongInputType,
5641 TosaErrorValidator.evWrongOutputType,
5642 TosaErrorValidator.evWrongInputList,
5643 TosaErrorValidator.evWrongOutputList,
5644 TosaErrorValidator.evInputZeroPointNotZero,
5645 TosaErrorValidator.evWeightZeroPointNotZero,
5646 TosaErrorValidator.evPadSmallerZero,
5647 TosaErrorValidator.evStrideSmallerOne,
5648 TosaErrorValidator.evDilationSmallerOne,
5649 TosaErrorValidator.evWrongRank,
5650 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005651 "template": True,
5652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005653 "fully_connected": {
5654 "op": Op.FULLY_CONNECTED,
5655 "operands": (1, 2),
5656 "rank": (2, 2),
5657 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
5658 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005659 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005660 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
5661 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005662 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005663 "matmul": {
5664 "op": Op.MATMUL,
5665 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07005666 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08005667 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
5668 "qgen": TosaQuantGen.qgMatmul,
5669 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005670 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5671 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005673 "max_pool2d": {
5674 "op": Op.MAX_POOL2D,
5675 "operands": (1, 0),
5676 "rank": (4, 4),
5677 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5678 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00005679 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005680 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5681 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5682 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005683 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005684 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005685 "transpose_conv2d_TEMPLATE": {
5686 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07005687 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005688 "rank": (4, 4),
5689 "build_fcn": (
5690 build_transpose_conv2d,
5691 TosaTensorGen.tgTransposeConv2D,
5692 TosaArgGen.agTransposeConv2D,
5693 ),
5694 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005695 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005696 "invalid_test_validators": (
5697 TosaInvalidValidator.ivHeightWidthInvalid,
5698 TosaInvalidValidator.ivNonPositiveOutputShape,
5699 ),
5700 "error_if_validators": (
5701 TosaErrorValidator.evWrongInputType,
5702 TosaErrorValidator.evWrongOutputType,
5703 TosaErrorValidator.evWrongInputList,
5704 TosaErrorValidator.evWrongOutputList,
5705 TosaErrorValidator.evInputZeroPointNotZero,
5706 TosaErrorValidator.evWeightZeroPointNotZero,
5707 TosaErrorValidator.evPadSmallerZero,
5708 TosaErrorValidator.evStrideSmallerOne,
5709 TosaErrorValidator.evDilationSmallerOne,
5710 TosaErrorValidator.evWrongRank,
5711 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005712 "template": True,
5713 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005714 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08005715 "clamp": {
5716 "op": Op.CLAMP,
5717 "operands": (1, 0),
5718 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
5719 "types": TYPE_NARROW_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005720 "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5721 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005722 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08005723 "sigmoid": {
5724 "op": Op.SIGMOID,
5725 "operands": (1, 0),
5726 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
5727 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005728 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5729 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005730 },
5731 "tanh": {
5732 "op": Op.TANH,
5733 "operands": (1, 0),
5734 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
5735 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005736 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5737 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005738 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005739 # Elementwise Binary Operators
5740 "add": {
5741 "op": Op.ADD,
5742 "operands": (2, 0),
5743 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5744 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005745 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005746 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005747 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005748 "arithmetic_right_shift": {
5749 "op": Op.ARITHMETIC_RIGHT_SHIFT,
5750 "operands": (2, 0),
5751 "build_fcn": (
5752 build_arithmetic_right_shift,
5753 TosaTensorGen.tgBroadcastFuzz,
5754 TosaArgGen.agArithmeticRightShift,
5755 ),
5756 "types": TYPE_INT,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005757 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5758 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005759 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005760 "bitwise_and": {
5761 "op": Op.BITWISE_AND,
5762 "operands": (2, 0),
5763 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5764 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005765 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005766 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005767 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005768 "bitwise_or": {
5769 "op": Op.BITWISE_OR,
5770 "operands": (2, 0),
5771 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5772 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005773 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005774 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005775 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005776 "bitwise_xor": {
5777 "op": Op.BITWISE_XOR,
5778 "operands": (2, 0),
5779 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5780 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005781 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005782 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005783 },
Matthew Haddon459443c2021-08-23 16:43:13 +01005784 "intdiv": {
5785 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005786 "operands": (2, 0),
5787 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5788 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005789 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005790 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005792 "logical_and": {
5793 "op": Op.LOGICAL_AND,
5794 "operands": (2, 0),
5795 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5796 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005797 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005798 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005799 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005800 "logical_left_shift": {
5801 "op": Op.LOGICAL_LEFT_SHIFT,
5802 "operands": (2, 0),
5803 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5804 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005805 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005806 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005808 "logical_right_shift": {
5809 "op": Op.LOGICAL_RIGHT_SHIFT,
5810 "operands": (2, 0),
5811 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5812 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005813 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005814 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005815 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005816 "logical_or": {
5817 "op": Op.LOGICAL_OR,
5818 "operands": (2, 0),
5819 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5820 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005821 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005822 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005823 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005824 "logical_xor": {
5825 "op": Op.LOGICAL_XOR,
5826 "operands": (2, 0),
5827 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5828 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005829 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005830 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005831 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005832 "maximum": {
5833 "op": Op.MAXIMUM,
5834 "operands": (2, 0),
5835 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5836 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005837 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005838 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005839 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005840 "minimum": {
5841 "op": Op.MINIMUM,
5842 "operands": (2, 0),
5843 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5844 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005845 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005846 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005847 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005848 "mul": {
5849 "op": Op.MUL,
5850 "operands": (2, 0),
5851 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
5852 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005853 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005854 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005855 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005856 "pow": {
5857 "op": Op.POW,
5858 "operands": (2, 0),
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005859 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08005860 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005861 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005862 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005864 "sub": {
5865 "op": Op.SUB,
5866 "operands": (2, 0),
5867 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5868 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005869 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005870 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005871 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005872 "table": {
5873 "op": Op.TABLE,
5874 # Use the automatic generation functions to create the input array
5875 # but create the table tensor in the build function, as it may be
5876 # a different type from the input
5877 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00005878 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005879 "types": [DType.INT8, DType.INT16],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005880 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5881 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005882 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005883 # Elementwise Unary operators
5884 "abs": {
5885 "op": Op.ABS,
5886 "operands": (1, 0),
5887 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5888 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005889 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5890 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005891 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005892 "bitwise_not": {
5893 "op": Op.BITWISE_NOT,
5894 "operands": (1, 0),
5895 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5896 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005897 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5898 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005899 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005900 "ceil": {
5901 "op": Op.CEIL,
5902 "operands": (1, 0),
5903 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5904 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005905 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5906 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005907 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005908 "clz": {
5909 "op": Op.CLZ,
5910 "operands": (1, 0),
5911 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5912 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005913 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5914 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005915 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005916 "exp": {
5917 "op": Op.EXP,
5918 "operands": (1, 0),
5919 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5920 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005921 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5922 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005923 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005924 "floor": {
5925 "op": Op.FLOOR,
5926 "operands": (1, 0),
5927 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5928 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005929 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5930 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005931 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005932 "log": {
5933 "op": Op.LOG,
5934 "operands": (1, 0),
5935 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5936 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005937 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5938 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005939 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005940 "logical_not": {
5941 "op": Op.LOGICAL_NOT,
5942 "operands": (1, 0),
5943 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5944 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005945 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5946 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005947 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005948 "negate": {
5949 "op": Op.NEGATE,
5950 "operands": (1, 0),
5951 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5952 "qgen": TosaQuantGen.qgUnary,
5953 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005954 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5955 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5956 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005957 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005958 "reciprocal": {
5959 "op": Op.RECIPROCAL,
5960 "operands": (1, 0),
5961 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5962 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005963 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5964 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005966 "rsqrt": {
5967 "op": Op.RSQRT,
5968 "operands": (1, 0),
5969 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5970 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005971 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5972 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005973 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005974 # Elementwise Ternary operators
5975 "select": {
5976 "op": Op.SELECT,
5977 "operands": (3, 0),
5978 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
5979 "types": TYPE_FIB,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005980 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5981 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005982 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005983 # Comparison operators
5984 "equal": {
5985 "op": Op.EQUAL,
5986 "operands": (2, 0),
5987 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5988 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005989 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5990 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005991 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005992 "greater_equal": {
5993 "op": Op.GREATER_EQUAL,
5994 "operands": (2, 0),
5995 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5996 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005997 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5998 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005999 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006000 "greater": {
6001 "op": Op.GREATER,
6002 "operands": (2, 0),
6003 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6004 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006005 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6006 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08006007 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006008 # Reduction operators
6009 "reduce_all": {
6010 "op": Op.REDUCE_ALL,
6011 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006012 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006013 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6014 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006015 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6016 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6017 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006018 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006019 "reduce_any": {
6020 "op": Op.REDUCE_ANY,
6021 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006022 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006023 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6024 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006025 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6026 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6027 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006028 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006029 "reduce_max": {
6030 "op": Op.REDUCE_MAX,
6031 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006032 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006033 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6034 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006035 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6036 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6037 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006038 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006039 "reduce_min": {
6040 "op": Op.REDUCE_MAX,
6041 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006042 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006043 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6044 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006045 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6046 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6047 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006048 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006049 "reduce_product": {
6050 "op": Op.REDUCE_PRODUCT,
6051 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006052 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006053 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6054 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006055 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6056 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6057 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006058 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006059 "reduce_sum": {
6060 "op": Op.REDUCE_SUM,
6061 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006062 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006063 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6064 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006065 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6066 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6067 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006068 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006069 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006070 "concat": {
6071 "op": Op.CONCAT,
6072 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01006073 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006074 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006075 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
Matthew Haddon01c359d2021-10-15 16:30:48 +01006076 TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
6077 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006078 },
6079 "pad": {
6080 "op": Op.PAD,
6081 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006082 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006083 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
6084 "qgen": TosaQuantGen.qgPad,
6085 "types": TYPE_FIB,
Jeremy Johnson27cf5432021-11-16 11:12:17 +00006086 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
Matthew Haddone807aae2021-10-11 18:12:58 +01006087 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006088 },
6089 "reshape": {
6090 "op": Op.RESHAPE,
6091 "operands": (1, 0),
6092 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
6093 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006094 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6095 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006096 },
6097 "reverse": {
6098 "op": Op.REVERSE,
6099 "operands": (1, 0),
6100 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6101 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006102 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType,
6103 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006104 },
6105 "slice": {
6106 "op": Op.SLICE,
6107 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006108 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006109 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
6110 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006111 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
6112 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
6113 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006114 },
6115 "tile": {
6116 "op": Op.TILE,
6117 "operands": (1, 0),
6118 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
6119 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006120 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6121 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006122 },
6123 "transpose": {
6124 "op": Op.TRANSPOSE,
6125 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01006126 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006127 "build_fcn": (
6128 build_transpose,
6129 TosaTensorGen.tgBasic,
6130 TosaArgGen.agTranspose,
6131 ),
6132 "types": TYPE_FIB,
Jeremy Johnson27cf5432021-11-16 11:12:17 +00006133 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice,
Matthew Haddone807aae2021-10-11 18:12:58 +01006134 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006135 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006136 # Data nodes
6137 "const": {
6138 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07006139 "operands": (0, 1),
6140 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08006141 "types": TYPE_FIB,
6142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006143 "identity": {
6144 "op": Op.IDENTITY,
6145 "operands": (1, 0),
6146 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6147 "types": TYPE_FIB,
6148 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006149 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08006150 "gather": {
6151 "op": Op.GATHER,
6152 # Only specify 'values' tensor here. 'indices' is generated in op building stage
6153 "operands": (1, 0),
6154 "rank": (3, 3),
6155 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
6156 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006157 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006158 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006159 },
6160 "scatter": {
6161 "op": Op.SCATTER,
6162 # Only specify 'values_in' tensor here.
6163 #'indices' and 'input' are generated in op building stage
6164 "operands": (2, 0),
6165 "rank": (3, 3),
6166 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
6167 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006168 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006169 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006170 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006171 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08006172 "resize": {
6173 "op": Op.RESIZE,
6174 "operands": (1, 0),
6175 "rank": (4, 4),
6176 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
6177 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01006178 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
6179 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
6180 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01006181 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01006182 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
6183 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006184 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006185 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08006186 "cast": {
6187 "op": Op.CAST,
6188 "operands": (1, 0),
6189 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
6190 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006191 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6192 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006193 },
6194 "rescale": {
6195 "op": Op.RESCALE,
6196 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01006197 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006198 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01006199 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01006200 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
6201 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6202 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006203 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006204 # Custom
6205 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08006206 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07006207 # Two varients of cond_if, one that generates one of two constant tensors (no
6208 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
6209 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006210 "cond_if_const": {
6211 "op": Op.COND_IF,
6212 "operands": (0, 2),
6213 "build_fcn": (
6214 build_cond_if_const,
6215 TosaTensorGen.tgBasic,
6216 TosaArgGen.agCondIf,
6217 ),
6218 "types": [DType.BOOL],
Matthew Haddon630c17c2021-10-14 15:05:41 +01006219 "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006220 },
6221 "cond_if_binary": {
6222 "op": Op.COND_IF,
6223 "operands": (2, 0),
6224 "build_fcn": (
6225 build_cond_if_binary,
6226 TosaTensorGen.tgBasic,
6227 TosaArgGen.agCondIf,
6228 ),
Les Bell6040b4d2021-10-11 12:50:31 +01006229 "types": TYPE_INT_FP,
Matthew Haddon630c17c2021-10-14 15:05:41 +01006230 "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch,
6231 TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006232 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006233 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08006234 "while_loop": {
6235 "op": Op.WHILE_LOOP,
6236 "operands": (0, 1),
6237 "build_fcn": (
6238 build_while_loop,
6239 TosaTensorGen.tgBasic,
6240 TosaArgGen.agWhileLoop,
6241 ),
6242 "types": [DType.INT32],
Matthew Haddon630c17c2021-10-14 15:05:41 +01006243 "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch,
6244 TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch,
6245 TosaErrorValidator.evCondGraphOutputNotMatchingBool)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006246 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006247 }
6248
Kevin Cheng550ccc52021-03-03 11:21:43 -08006249
Eric Kunzee5e26762020-10-13 16:11:07 -07006250class OutputShaper:
6251 # Methods in this class compute the expected output shape and datatype
6252 # for common classes of operations
    def __init__(self):
        # OutputShaper holds no per-instance state; every shape/dtype helper
        # below is a @staticmethod, so construction is a no-op.
        pass
6255
6256 # These methods return arguments that can be used for
6257 # creating a new output tensor
6258 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01006259 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
6260 if error_name != ErrorIf.RankMismatch:
6261 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006262 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006263
6264 shape = []
6265 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01006266 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07006267 shape.append(b.shape[i])
6268 else:
6269 shape.append(a.shape[i])
6270
Matthew Haddoneacff9a2021-09-24 14:42:13 +01006271 if error_name == ErrorIf.WrongOutputType:
6272 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6273 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6274 outputDType = rng.choice(wrong_dtypes)
6275 else:
6276 outputDType = a.dtype
6277
6278 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006279
6280 @staticmethod
6281 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006282 assert len(a.shape) == len(b.shape)
6283 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006284
6285 shape = []
6286 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006287 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07006288 shape.append(a.shape[i])
6289
Kevin Cheng550ccc52021-03-03 11:21:43 -08006290 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006291
6292 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01006293 def unaryOp(ser, rng, a, error_name=None):
6294 if error_name == ErrorIf.WrongOutputType:
6295 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6296 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6297 outputDType = rng.choice(wrong_dtypes)
6298 else:
6299 outputDType = a.dtype
6300
6301 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006302
6303 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006304 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006305 if error_name != ErrorIf.RankMismatch:
6306 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006307 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006308
6309 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006310 for i in range(len(cond.shape)):
6311 if cond.shape[i] == 1 and error_name == None:
6312 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
6313 else:
6314 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07006315
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006316 if error_name == ErrorIf.WrongOutputType:
6317 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6318 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6319 outputDType = rng.choice(wrong_dtypes)
6320 else:
6321 outputDType = a.dtype
6322
6323 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006324
6325 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006326 def binaryComparisonOp(ser, rng, a, b , error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006327 if error_name != ErrorIf.RankMismatch:
6328 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006329 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006330
6331 # Do broadcast
6332 shape = []
6333 for i in range(len(a.shape)):
6334 if a.shape[i] == 1:
6335 shape.append(b.shape[i])
6336 else:
6337 shape.append(a.shape[i])
6338
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006339 if error_name == ErrorIf.WrongOutputType:
6340 wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6341 outputDType = rng.choice(wrong_dtypes)
6342 else:
6343 outputDType = DType.BOOL
6344
6345 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006346
    @staticmethod
    def reduceOp(ser, rng, a, axis, error_name=None):
        """Compute a reduction op's output shape (the reduced axis becomes 1).

        For axis-related error_if cases the axis dimension is left alone or
        deliberately forced to a value != 1; WrongOutputType substitutes a
        random dtype different from the input's.
        """
        shape = a.shape.copy()
        # Only collapse the axis when the test isn't deliberately breaking it.
        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
            shape[axis] = 1
        if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
            shape[axis] = rng.integers(2, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006363
    @staticmethod
    def argmaxOp(ser, rng, a, axis, error_name=None):
        """Compute the ARGMAX output shape (input shape with `axis` removed).

        Error_if cases can leave the axis in place, perturb the rank, perturb
        every dimension, or substitute a non-INT32 output dtype.
        """
        shape = a.shape.copy()

        # Only drop the axis when the test isn't deliberately breaking it.
        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
            del shape[axis]

        if error_name == ErrorIf.ArgmaxOutputRankMismatch:
            # Randomly either remove a dim (when possible) or append one so the
            # output rank no longer matches rank(input) - 1.
            remove = rng.choice([True, False])
            if remove and len(shape) > 1:
                del shape[0]
            else:
                shape.append(1)
        elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
            # Inflate every dimension so the shape cannot match.
            for i in range(len(shape)):
                shape[i] = shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            # ARGMAX output is always INT32, so "wrong" excludes only INT32.
            wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = DType.INT32

        return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006389
    @staticmethod
    def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
        """Compute the CONV2D output (OFM) shape/dtype and register it.

        Args:
            ser: serializer the output tensor is added to
            rng: random generator, used to pick a wrong dtype for error_if tests
            ifm: input feature map tensor, layout NHWC
            filter: weight tensor, layout OHWI
            strides, padding, dilations: conv attributes; padding is
                [top, bottom, left, right] (or [H, W], expanded below)
            error_name: optional ErrorIf case being generated
        Returns:
            The serializer output tensor, layout NHWC.
        """

        # IFM: NHWC
        # Filter: OHWI
        # OFM: NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        # Standard conv arithmetic: the filter term plus the dilation term
        # equals dilation * (kernel - 1) + 1, i.e. the effective kernel extent.
        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        # Avoid illegal dimensions, which can be generated in error_if tests
        h = max(h, 1)
        w = max(w, 1)

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        # Accumulator dtype follows the TOSA conv rules for the input dtype.
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006441
    @staticmethod
    def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
        """Compute the CONV3D output (OFM) shape/dtype and register it.

        Same arithmetic as conv2dOp extended with a depth dimension.

        Args:
            ser: serializer the output tensor is added to
            rng: random generator, used to pick a wrong dtype for error_if tests
            ifm: input feature map tensor, layout NDHWC
            filter: weight tensor, layout ODHWI
            strides, padding, dilations: conv attributes; padding is
                [d0, d1, top, bottom, left, right]
            error_name: optional ErrorIf case being generated
        Returns:
            The serializer output tensor, layout NDHWC.
        """

        # IFM: NDHWC
        # Filter: ODHWI
        # OFM: NDHWC

        d = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        h = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        w = (
            ifm.shape[3]
            - filter.shape[3]
            - (filter.shape[3] - 1) * (dilations[2] - 1)
            + padding[4]
            + padding[5]
        ) // strides[2] + 1

        # Avoid illegal dimensions, which can be generated in error_if tests
        d = max(d, 1)
        h = max(h, 1)
        w = max(w, 1)

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

        # Accumulator dtype follows the TOSA conv rules for the input dtype.
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, out_dtype)
6497
    @staticmethod
    def depthwiseConv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
        """Compute the DEPTHWISE_CONV2D output shape/dtype and register it.

        Output channel count is C * M (input channels times the channel
        multiplier from the filter's last axis).

        Args:
            ser: serializer the output tensor is added to
            rng: random generator, used to pick a wrong dtype for error_if tests
            ifm: input feature map tensor, layout NHWC
            filter: weight tensor, layout HWCM
            strides, padding, dilations: conv attributes;
                padding is [top, bottom, left, right]
            error_name: optional ErrorIf case being generated
        Returns:
            The serializer output tensor, layout NHWC with C*M channels.
        """
        # IFM: NHWC
        # Filter: HWCM
        # OFM: NHW C*M
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        # Avoid illegal dimensions, which can be generated in error_if tests
        h = max(h, 1)
        w = max(w, 1)

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        # Accumulator dtype follows the TOSA conv rules for the input dtype.
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006542
    @staticmethod
    def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
        """Compute a 2D pooling op's output shape/dtype and register it.

        Args:
            ser: serializer the output tensor is added to
            rng: random generator, used for error_if perturbations
            ifm: input feature map tensor, layout NHWC
            kernel: [kernel_h, kernel_w]
            stride: [stride_h, stride_w]
            pad: [top, bottom, left, right]
            error_name: optional ErrorIf case being generated
        Returns:
            The serializer output tensor, layout NHWC.
        """
        # input: NHWC
        if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
            # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
            h = 1
            w = 1
        else:
            h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
            w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

        if error_name == ErrorIf.PoolingOutputShapeMismatch:
            # Deliberately inflate H/W so the declared output shape is wrong.
            choices = [1, 2, 3, 4, 5]
            h = h + rng.choice(choices)
            w = w + rng.choice(choices)

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            # Pooling preserves the input dtype.
            outputDType = ifm.dtype

        return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006569
6570 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006571 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006572 # input: N, IC
6573 # filter: OC, IC
6574 # output: N, OC
6575
6576 output_shape = [input.shape[0], filter.shape[0]]
6577
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006578 if error_name == ErrorIf.WrongOutputType:
6579 if input.dtype == DType.INT8:
6580 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
6581 elif input.dtype == DType.INT16:
6582 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
6583 elif input.dtype == DType.FLOAT:
6584 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
6585 out_dtype = rng.choice(a=incorrect_types)
6586 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006587 out_dtype = DType.INT32
6588 elif input.dtype == DType.INT16:
6589 out_dtype = DType.INT48
6590 elif input.dtype == DType.FLOAT:
6591 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006592 elif error_name == ErrorIf.WrongInputType:
6593 # Pick some potentially correct output dtype if input type is incorrect
6594 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006595 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006596 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07006597
Kevin Cheng550ccc52021-03-03 11:21:43 -08006598 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006599
6600 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006601 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07006602 # a: N, H, C
6603 # b: N, C, W
6604 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07006605
Kevin Cheng2d60f002021-06-09 14:18:32 -07006606 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07006607
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006608 if error_name == ErrorIf.WrongOutputType:
6609 if a.dtype == DType.INT8:
6610 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
6611 elif a.dtype == DType.INT16:
6612 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
6613 elif a.dtype == DType.FLOAT:
6614 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
6615 out_dtype = rng.choice(a=incorrect_types)
6616 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07006617 out_dtype = DType.INT32
6618 elif a.dtype == DType.INT16:
6619 out_dtype = DType.INT48
6620 elif a.dtype == DType.FLOAT:
6621 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006622 elif error_name == ErrorIf.WrongInputType:
6623 # Pick some potentially correct output dtype if input type is incorrect
6624 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006625 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006626 raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07006627
Kevin Cheng550ccc52021-03-03 11:21:43 -08006628 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006629
    @staticmethod
    def concatOp(ser, rng, axis, *a, error_name=None):
        """Compute the CONCAT output shape/dtype and register it.

        Args:
            ser: serializer the output tensor is added to
            rng: random generator, used for error_if perturbations
            axis: concatenation axis
            *a: one or more input tensors; the first one provides the base
                shape and the output dtype
            error_name: optional ErrorIf case being generated
        Returns:
            The serializer output tensor.
        """
        input1 = a[0]
        remaining_inputs = a[1:]

        # calculate the output shape, if possible, otherwise just use the first input shape
        output_shape = input1.shape.copy()
        if not (
            # unable to concat tensors of different ranks
            error_name == ErrorIf.ConcatInputRankMismatch
            # unable to concat tensors along an invalid axis
            or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
        ):
            for tensor in remaining_inputs:
                output_shape[axis] += tensor.shape[axis]

        if error_name == ErrorIf.ConcatShapeSumMismatch:
            # Deliberately make the axis dim disagree with the inputs' sum.
            output_shape[axis] += rng.integers(5, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
            wrong_dtypes = list(all_dtypes - set([input1.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = input1.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006657
6658 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01006659 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006660
6661 output_shape = a.shape.copy()
6662
6663 for i in range(len(output_shape)):
6664 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
6665
Matthew Haddone807aae2021-10-11 18:12:58 +01006666 # Fix negative output shape if error_if test causes it
6667 if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
6668 output_shape = [i if i >= 1 else 1 for i in output_shape]
6669
6670 if error_name == ErrorIf.WrongOutputType:
6671 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6672 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6673 outputDType = rng.choice(wrong_dtypes)
6674 else:
6675 outputDType = a.dtype
6676
6677 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006678
    @staticmethod
    def reshapeOp(ser, rng, a, shape, error_name=None):
        """Compute the RESHAPE output shape/dtype and register it.

        Resolves a single -1 wildcard dimension from the input's element
        count, numpy-style. Error_if cases can perturb every dimension or
        the output dtype.
        """
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            # Inflate every dim so input/output element counts disagree.
            for i in range(len(output_shape)):
                output_shape[i] = output_shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006710
    @staticmethod
    def sliceOp(ser, rng, a, start, size, error_name=None):
        """Compute the SLICE output shape/dtype and register it.

        The output shape is the `size` attribute; `start` is unused here
        (it only affects data, not the output shape). Error_if cases can
        perturb each size entry or the output dtype.
        """

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        if error_name == ErrorIf.SizeOutputShapeMismatch:
            output_shape = size.copy()
            for index in range(len(output_shape)):
                # Perturb each dim; only grow small dims so they stay positive.
                if output_shape[index] <= 2:
                    output_shape[index] = output_shape[index] + rng.choice([1, 2])
                else:
                    output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2])
        else:
            output_shape = size.copy()

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006732
6733 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006734 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07006735
6736 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08006737 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07006738
6739 for i in range(len(output_shape)):
6740 output_shape[i] = a.shape[i] * multiples[i]
6741
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006742 if error_name == ErrorIf.WrongOutputType:
6743 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6744 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
6745 outputDType = rng.choice(wrong_dtypes)
6746 else:
6747 outputDType = a.dtype
6748
6749 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006750
    @staticmethod
    def transposeOp(ser, rng, a, perms, error_name=None):
        """Compute the TRANSPOSE output shape/dtype and register it.

        Output dim i is input dim perms[i]. For the IndexOutsideBounds
        error case, perms entries may be invalid indices, so every output
        dim falls back to a.shape[0] instead of indexing with them.
        """
        output_shape = a.shape.copy()

        assert len(perms) == len(output_shape)

        if error_name == ErrorIf.IndexOutsideBounds:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[0]
        else:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[perms[i]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006772
    @staticmethod
    def gatherOp(ser, rng, values, indices, error_name=None):
        """Compute the GATHER output shape/dtype and register it.

        values: [N, K, C], indices: [N, W] -> output [N, W, C].
        """
        if error_name != ErrorIf.WrongRank:
            assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = values.dtype

        return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08006790
    @staticmethod
    def scatterOp(ser, rng, values_in, indices, input, error_name=None):
        """Compute the SCATTER output shape/dtype and register it.

        values_in: [N, K, C], indices: [N, W], input: [N, W, C];
        output has values_in's shape (scatter updates in place).
        """
        if error_name != ErrorIf.WrongRank:
            assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = values_in.dtype

        return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006811
6812 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006813 def tableOp(ser, rng, input, error_name=None):
6814 # Same shape as the input, dtype dependent on input dtype
6815 if error_name != ErrorIf.WrongInputType:
6816 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00006817 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006818 if error_name == ErrorIf.WrongOutputType:
6819 wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
6820 wrong_dtypes.remove(output_dtype)
6821 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01006822 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006823
    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name = None
    ):
        """Compute the RESIZE output shape and register it.

        Only input, output_dims, output_dtype and error_name are used here;
        the remaining parameters (mode, stride, offset, shift, stride_fp,
        offset_fp, input_dtype) are accepted but unused — presumably kept
        for signature parity with the resize argument generator; confirm
        against callers before removing.
        """
        if error_name == ErrorIf.WrongRank:
            # Deliberately malformed dims for the wrong-rank error case.
            output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
        else:
            if error_name == ErrorIf.BatchMismatch:
                output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
            elif error_name == ErrorIf.ChannelMismatch:
                output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
            else:
                # Normal case: NHWC with requested output H/W, input N and C.
                output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006851
    @staticmethod
    def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
        # CAST/RESCALE: output keeps the input's shape with the requested
        # target dtype. rng and error_name are accepted for signature parity
        # with the other output shapers but are not used here.
        return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006855
    @staticmethod
    def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
        """Register the TRANSPOSE_CONV2D output tensor.

        The output shape is supplied by the caller; only the dtype is
        derived here, following the TOSA conv accumulator rules for the
        input dtype.
        """
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, out_dtype)