blob: 1bd1b5a9f4d73f5526b4c12b42d6a9d57843b08b [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
Eric Kunzee5e26762020-10-13 16:11:07 -070017import numpy as np
18import argparse
19import sys
20import re
21import os
22import subprocess
23import shlex
24import json
25import glob
26import math
27import queue
28import threading
29import traceback
30import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010031import itertools
Matthew Haddon630c17c2021-10-14 15:05:41 +010032from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
Les Bell0e027d42021-11-09 14:42:14 +000048from tosa.DType import DType
49from tosa.Op import Op
50from tosa.ResizeMode import ResizeMode
Eric Kunzee5e26762020-10-13 16:11:07 -070051
Matthew Haddon630c17c2021-10-14 15:05:41 +010052
def valueToName(item, value):
    """Look up the attribute name under which ``value`` is stored on ``item``.

    The flatc-generated tosa.Op.Op and tosa.DType.DType classes are plain
    classes of constants rather than Enum/IntEnum subclasses, so this helper
    is needed to turn a raw value back into a readable name.

    Args:
        item: The class, or object, whose attributes are searched.
        value: The value to look for.

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)           # returns 'INT8'

    Returns:
        The first attribute name, in ``dir(item)`` order, whose value
        compares equal to ``value``.

    Raises:
        ValueError: if no attribute has a matching value.
    """
    # dir() yields attribute names (strings), so None is a safe "not found" marker.
    candidates = (attr for attr in dir(item) if getattr(item, attr) == value)
    found = next(candidates, None)
    if found is None:
        raise ValueError(f'value ({value}) not found')
    return found
80
def allDTypes(*, excludes=None):
    """Return the set of DType constant values, minus any in ``excludes``.

    DType is a plain flatc-generated class rather than an Enum, so its values
    are discovered by walking ``dir(DType)`` and keeping the non-dunder,
    non-callable attributes.

    Args:
        excludes: optional iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL]).

    Returns:
        A set of DType values.
    """
    skip = excludes if excludes else ()
    values = set()
    for attr_name in dir(DType):
        if attr_name.startswith('__'):
            continue
        attr_value = getattr(DType, attr_name)
        if callable(attr_value) or attr_value in skip:
            continue
        values.add(attr_value)
    return values
99
def usableDTypes(*, excludes=None):
    """Return the DType values the serializer lib can handle.

    DType.UNKNOWN and DType.UINT8 are always left out, together with any
    values the caller passes in ``excludes``. Use allDTypes instead if
    'UNKNOWN' or 'UINT8' must be included.

    Args:
        excludes: optional iterable of extra DType values to leave out
            (e.g. [DType.INT8, DType.BOOL]).

    Returns:
        A set of DType values.
    """
    omitted = {DType.UNKNOWN, DType.UINT8}.union(excludes or ())
    return allDTypes(excludes=omitted)
116
def product(shape):
    """Return the product of the entries of ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total = total * dim
    return total
122
Les Bell0e027d42021-11-09 14:42:14 +0000123
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a random zero point for ``dtype``.

        INT8 uses testGen.randInt(-128, 128) and UINT8 uses
        testGen.randInt(0, 256). Any other dtype gets 0, unless ``error_name``
        is one of the *ZeroPointNotZero ERROR_IF cases. In that case a
        non-zero value is returned so that the error condition is triggered.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)

        zero_point_errors = (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        )
        if error_name not in zero_point_errors:
            return 0

        zero_point = testGen.randInt(-128, 128)
        # Never hand back 0 here: the test needs a non-zero zero point.
        return zero_point if zero_point != 0 else 1

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo with (input_zp, output_zp)."""
        qinfo = ts.TosaSerializerQuantInfo()
        # Only the operand named by the error case receives error_name.
        input_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_error = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        # Input zero point is drawn first, then output (keeps RNG order stable).
        input_zp = TosaQuantGen.getQinfo(testGen, dtype, input_error)
        output_zp = TosaQuantGen.getQinfo(testGen, dtype, output_error)
        qinfo.UnaryQuantInfo(input_zp, output_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo with (input_zp, weights_zp).

        ``dtype_or_dtypeList`` is either a single dtype shared by input,
        weights and accumulator, or a list of [input, weights, accumulator]
        dtypes.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            input_dtype = dtype_or_dtypeList[0]
            weights_dtype = dtype_or_dtypeList[1]
        else:
            input_dtype = weights_dtype = dtype_or_dtypeList

        input_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weights_error = error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        input_zp = TosaQuantGen.getQinfo(testGen, input_dtype, input_error)
        weights_zp = TosaQuantGen.getQinfo(testGen, weights_dtype, weights_error)
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo; an input zero-point error applies to both zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        zp_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        first_zp = TosaQuantGen.getQinfo(testGen, dtype, zp_error)
        second_zp = TosaQuantGen.getQinfo(testGen, dtype, zp_error)
        qinfo.MatMulQuantInfo(first_zp, second_zp)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo holding the input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        zp_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, zp_error))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32. ``scale32`` selects
        a 31-bit multiplier, otherwise 15 bits are used. The sign of
        ``scaleFp`` is discarded. The returned shift is asserted to lie in
        [2, 62].
        """
        scaleBits = 31 if scale32 else 15
        limit = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * limit)
        assert multiplier <= limit

        # Rounding can push the mantissa up to exactly 1.0; renormalise.
        if multiplier == limit:
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        # NOTE(review): these adjustments do not preserve multiplier / 2**shift
        # (e.g. shift 0 -> 2 with multiplier // 4 scales it by 1/16); confirm
        # this is intended.
        if shift == 0:
            multiplier, shift = multiplier // 4, 2
        elif shift == 1:
            multiplier, shift = multiplier // 2, 2
        elif shift == 63:
            multiplier, shift = multiplier * 2, 62

        assert multiplier <= limit
        assert 2 <= shift <= 62

        return multiplier, shift
247
248
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one shape per operand (placeholders + consts).

        ``opName["operands"]`` gives the (placeholder, const) counts. All
        operands share one shape of ``rank``. For ErrorIf.RankMismatch,
        operand 0 keeps ``rank`` and later operands receive regenerated shapes
        of a different rank. No new shape is generated after index 1, so
        operands 1 and 2 share one.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # Only grow rank-1 shapes (presumably to avoid rank 0)
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return an identical NHWC shape for every operand.

        ``rank`` must be 4 unless the error is ErrorIf.WrongRank. The batch
        dimension is folded into [1, args.max_batch_size] when that option is
        set.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for SCATTER.

        Both shapes are rank 3. input_shape shares dims 0 and 2 with
        values_in_shape and has a randomly chosen dim 1 (W).
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            # NOTE(review): a lower bound of 0 may allow W == 0 (an empty
            # input tensor) depending on randInt's bounds; confirm intended.
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        return [values_in_shape.copy(), input_shape.copy()]

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return per-operand shapes where one randomly chosen operand broadcasts.

        Normally one random dimension of the chosen operand is set to 1.
        For ErrorIf.DimensionMismatch that dimension is enlarged by 1 instead.
        For ErrorIf.RankMismatch the operand's rank is changed.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks)))
                    if rank != 1:
                        # Either keep the extra rank (rank + 1), or drop the
                        # last two dimensions (rank - 1); both mismatch rank.
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for CONV2D.

        ifm is NHWC. The filter is OHWI, with H/W taken from ``op["filter"]``
        and I matching the ifm channels. The bias has one entry per output
        channel.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for CONV3D.

        ifm is NDHWC. The filter is ODHWI, with D/H/W taken from
        ``op["filter"]``. The bias has one entry per output channel.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for TRANSPOSE_CONV2D.

        The layouts match tgConv2D: ifm NHWC, filter OHWI, bias OC.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for DEPTHWISE_CONV2D.

        ifm is NHWC and the filter is HWCM. The bias has C * M entries.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input_shape, filter_shape, bias_shape] for FULLY_CONNECTED.

        input is rank 2. The filter is [OC, input_shape[1]] and the bias is
        [OC], where OC is drawn from args.tensor_shape_range.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] for MATMUL.

        a is rank 3 and b is [a[0], a[2], b_oc]. b_oc is forced to 1 when any
        dimension of a exceeds 1000, to limit the output size.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return the shapes of the tensors to concatenate.

        There is one shape per operand plus testGen.randInt(0, 4) extra
        tensors, all identical. For ErrorIf.ConcatInputRankMismatch, every
        shape after the first either drops its leading dimension or gains a
        trailing one.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first shape along ``axis`` across the concat inputs.

        The returned list starts with a copy of shapeList[0]. Outside the
        dim-mismatch error case, the later entries share the original axis
        length between them, each split taking half of what remains.

        For ErrorIf.ConcatInputDimMismatch, the (axis + 1) dimension of the
        generated shapes is enlarged instead. Axis/rank ERROR_IF cases, and
        shape lists that cannot be split, are returned without splitting.
        """
        if error_name in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Calculate split on axis (exact integer halving) and remaining value
            split_shape_val = int(shape[axis] // 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
618
619
Eric Kunzee5e26762020-10-13 16:11:07 -0700620class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800621 """Argument generators create exhaustive or random lists of attributes for operators that take
622 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
623 tuples where the descriptive_name is appended to the test name and the arglist is expanded
624 as arguments to the operator build function."""
625
Eric Kunzee5e26762020-10-13 16:11:07 -0700626 def __init__(self):
627 pass
628
629 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100630 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800631 """A trivial argument generator for operators that don't take any
632 non-tensor arguments"""
633 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700634
635 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100636 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800637 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700638 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 shape = shapeList[0]
640
Matthew Haddond6ce7252021-09-29 15:35:44 +0100641 if error_name == ErrorIf.AxisSmallerZero:
642 small_axis = testGen.rng.integers(-5, 0)
643 axes.append(("axis{}".format(small_axis), [small_axis]))
644 elif error_name == ErrorIf.AxisLargerRank:
645 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
646 axes.append(("axis{}".format(large_axis), [large_axis]))
647 else:
648 for a in range(0, len(shape)):
649 axes.append(("axis{}".format(a), [a]))
650
Eric Kunzee5e26762020-10-13 16:11:07 -0700651 return axes
652
653 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100654 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700655 arg_list = []
656
657 ifm_shape = shapeList[0]
658 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100659 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
660 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
Les Bell7aa69f42021-09-20 10:44:07 +0100662 # Check the rank
663 rank = 5 if opName.startswith("conv3d") else 4
Les Bell0e027d42021-11-09 14:42:14 +0000664 if error_name != ErrorIf.WrongRank:
665 assert len(ifm_shape) == rank
666 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700667
Les Bell7aa69f42021-09-20 10:44:07 +0100668 # kernel rank omits batch and channels
669 k_rank = rank - 2
Les Bell0e027d42021-11-09 14:42:14 +0000670 assert len(k) == k_rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700671
Les Bell7aa69f42021-09-20 10:44:07 +0100672 # Generate comprehensive argument lists
Les Bell0e027d42021-11-09 14:42:14 +0000673 # - except for named errors, which use specific invalid value(s)
674 if error_name == ErrorIf.PadSmallerZero:
675 p_vals = [testGen.rng.choice(range(-5, 0))]
676 else:
677 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100678 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
Les Bell0e027d42021-11-09 14:42:14 +0000679 if error_name == ErrorIf.StrideSmallerOne:
680 # Can't use stride=0, as it is used to derive output shape, as a divisor
681 s_vals = [testGen.rng.choice(range(-5, 0))]
682 else:
683 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100684 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
Les Bell0e027d42021-11-09 14:42:14 +0000685 if error_name == ErrorIf.DilationSmallerOne:
686 d_vals = [testGen.rng.choice(range(-5, 1))]
687 else:
688 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100689 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700690
Les Bell0e027d42021-11-09 14:42:14 +0000691 if not error_name:
692 # add some oversize argument values
693 if max(ifm_shape) < 64:
694 bigPadding = 9
695 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
696 bigStride = 8
697 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
698 bigDilation = 7
699 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100700
Les Bell0e027d42021-11-09 14:42:14 +0000701 # There are too many parameter combinations, so generate them sparsely,
702 # very sparse for negative tests
703 sparsity_factor = 2 if error_name else 100
704 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
705 # If there are only a small number of tests, just select them all
Les Bell7aa69f42021-09-20 10:44:07 +0100706 if sparsity < 13:
707 sparsity = 1
Les Bell0e027d42021-11-09 14:42:14 +0000708 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
Les Bell7aa69f42021-09-20 10:44:07 +0100709 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
710 sparsity += 1
Les Bell0e027d42021-11-09 14:42:14 +0000711
Les Bellf414b3c2021-09-06 11:29:46 +0100712 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100713 for s in sorted(list(strides)):
714 for p in sorted(list(paddings)):
715 for d in sorted(list(dilations)):
716 if (n % sparsity == 0
717 # padding must not exceed the kernel size ?
718 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
719 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
720 # the padded shape must exceed the kernel size
721 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
722 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
723 # the padded shape must exceed the dilation
724 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
725 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
726 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100727 arg_list.append(
728 (
729 "st{}_pad{}_dilat{}".format(
730 "".join([str(x) for x in s]),
731 "".join([str(x) for x in p]),
732 "".join([str(x) for x in d]),
733 ),
734 [s, p, d],
735 )
736 )
737 n += 1
738
Kevin Cheng1533b852021-09-01 12:51:58 -0700739 return arg_list
740
741 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100742 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700743 arg_list = []
744
745 ifm_shape = shapeList[0]
746 filter_shape = shapeList[1]
747
748 # Must be rank 4
Les Bell0e027d42021-11-09 14:42:14 +0000749 if error_name != ErrorIf.WrongRank:
750 assert len(ifm_shape) == 4
751 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700752
Les Bell7aa69f42021-09-20 10:44:07 +0100753 # Generate comprehensive argument lists
Les Bell0e027d42021-11-09 14:42:14 +0000754 # - except for named errors, which use specific invalid value(s)
755 if error_name == ErrorIf.PadSmallerZero:
756 p_vals = [testGen.rng.choice(range(-5, 0))]
757 else:
758 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100759 paddings = {x for x in itertools.product(*([p_vals] * 2))}
Les Bell0e027d42021-11-09 14:42:14 +0000760 if error_name == ErrorIf.StrideSmallerOne:
761 # Can't use stride=0, as it is used to derive output shape, as a divisor
762 s_vals = [testGen.rng.choice(range(-5, 0))]
763 else:
764 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100765 strides = {x for x in itertools.product(*([s_vals] * 2))}
Les Bell0e027d42021-11-09 14:42:14 +0000766 if error_name == ErrorIf.DilationSmallerOne:
767 d_vals = [testGen.rng.choice(range(-5, 1))]
768 else:
769 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Les Bell7aa69f42021-09-20 10:44:07 +0100770 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700771
Les Bell0e027d42021-11-09 14:42:14 +0000772 if not error_name:
773 # add some oversize argument values
774 if max(ifm_shape) < 64:
775 bigPadding = 9
776 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
777 bigStride = 8
778 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
779 bigDilation = 7
780 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700781
Les Bell0e027d42021-11-09 14:42:14 +0000782 # There are too many parameter combinations, so generate them sparsely,
783 # very sparse for negative tests
784 sparsity_factor = 2 if error_name else 100
785 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
786 # If there are only a small number of tests, just select them all
Les Bell7aa69f42021-09-20 10:44:07 +0100787 if sparsity < 13:
788 sparsity = 1
Les Bell0e027d42021-11-09 14:42:14 +0000789 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
Les Bell7aa69f42021-09-20 10:44:07 +0100790 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
791 sparsity += 1
Les Bell0e027d42021-11-09 14:42:14 +0000792
Les Bell7aa69f42021-09-20 10:44:07 +0100793 n = 0
794 for s in sorted(list(strides)):
795 for p in sorted(list(paddings)):
796 for d in sorted(list(dilations)):
797 if n % sparsity == 0:
798 # Determine the output shape
799 oh = (
800 ifm_shape[1]
801 - filter_shape[1]
802 - (filter_shape[1] - 1) * (d[0] - 1)
803 + 2 * p[0]
804 ) // s[0] + 1
805 ow = (
806 ifm_shape[2]
807 - filter_shape[2]
808 - (filter_shape[2] - 1) * (d[1] - 1)
809 + 2 * p[1]
810 ) // s[1] + 1
811 os = [ifm_shape[0], oh, ow, filter_shape[0]]
812 arg_list.append(
813 (
814 "st{}_pad{}_dilat{}_os{}".format(
815 "".join([str(x) for x in s]),
816 "".join([str(x) for x in p]),
817 "".join([str(x) for x in d]),
818 "x".join([str(x) for x in os]),
819 ),
820 [s, p, d, os],
821 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800822 )
Les Bell7aa69f42021-09-20 10:44:07 +0100823 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700824
825 return arg_list
826
827 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100828 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700829 arg_list = []
830 rank = len(shapeList[0])
831
Les Bell7ffccce2021-07-28 15:37:02 +0100832 # Exhaustively test combinations of padding on each side of each dimension
833 # - the range of padding values is defined by pad_min and pad_max
834 # - for padding >9, the name format needs to be more distinctive
835 pad_min, pad_max = 0, 1
836 pad_values = [x for x in range(pad_min, pad_max + 1)]
Matthew Haddone807aae2021-10-11 18:12:58 +0100837 if error_name == ErrorIf.PadSmallerZero:
838 pad_values = [x for x in range(-2, 0)]
Les Bell7ffccce2021-07-28 15:37:02 +0100839 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
840 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700841
Kevin Chengfe392ce2021-10-18 21:51:55 +0000842 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
843 pad_const_int = testGen.getRandNumberDType(dtype)
844 pad_const_fp = 0
845 elif dtype == DType.FLOAT:
846 pad_const_int = 0
847 pad_const_fp = testGen.getRandNumberDType(dtype)
848 else:
849 return []
850
Les Bell7ffccce2021-07-28 15:37:02 +0100851 for paddings in shape_pad_values:
852 name = "pad"
853 for r in range(rank):
854 before, after = paddings[r]
855 name = f"{name}{before}{after}"
Kevin Chengfe392ce2021-10-18 21:51:55 +0000856 arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700857
858 return arg_list
859
860 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100861 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700862 arg_list = []
863
864 shape = shapeList[0]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100865 if error_name != ErrorIf.WrongRank:
866 assert len(shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700867
Les Bell7aa69f42021-09-20 10:44:07 +0100868 # Generate comprehensive argument lists
869 p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
870 paddings = {x for x in itertools.product(*([p_vals] * 4))}
871 s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
872 strides = {x for x in itertools.product(*([s_vals] * 2))}
873 k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
874 kernels = {x for x in itertools.product(*([k_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700875
Les Bell7aa69f42021-09-20 10:44:07 +0100876 # add some oversize argument values
877 bigStride = 7
878 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
879 bigKernel = 6
880 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
881 if max(shape) < 64:
882 # padding must be less than the kernel size
883 bigPadding = bigKernel - 1
884 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700885
Les Bell0e027d42021-11-09 14:42:14 +0000886 # There are too many parameter combinations, so generate them sparsely,
887 # very sparse for negative tests
888 sparsity_factor = 2 if error_name else 500
889 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
890
Les Bell7aa69f42021-09-20 10:44:07 +0100891 n = 0
892 for s in sorted(list(strides)):
893 for p in sorted(list(paddings)):
894 for k in sorted(list(kernels)):
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100895 if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
896 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
897 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
898 arg_list.append(
899 (
900 "st{}_kern{}_pad{}".format(
901 "".join([str(x) for x in sNew]),
902 "".join([str(x) for x in kNew]),
903 "".join([str(x) for x in pNew]),
904 ),
905 [sNew, pNew, kNew],
906 )
907 )
908 elif (n % sparsity == 0
Les Bell7aa69f42021-09-20 10:44:07 +0100909 # padding must not exceed the kernel size
910 and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
911 # the padded shape must exceed the kernel size
912 and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
913 ):
914 arg_list.append(
915 (
916 "st{}_kern{}_pad{}".format(
917 "".join([str(x) for x in s]),
918 "".join([str(x) for x in k]),
919 "".join([str(x) for x in p]),
920 ),
921 [s, p, k],
922 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800923 )
Les Bell7aa69f42021-09-20 10:44:07 +0100924 n += 1
925
Eric Kunzee5e26762020-10-13 16:11:07 -0700926 return arg_list
927
928 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100929 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700930 arg_list = []
931
932 # Enumerate the output types here
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100933 if error_name == ErrorIf.WrongOutputType:
934 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
935 elif inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800936 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700937 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800938 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700939 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800940 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700941 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800942 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700943 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800944 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100945 elif error_name == ErrorIf.WrongInputType:
946 # Pick some potentially correct output type for incorrect input type
947 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700948 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800949 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700950
951 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800952 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700953
954 return arg_list
955
956 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100957 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700958 arg_list = []
959
960 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100961 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100962 if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
963 continue
964 if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100965 # The only output dtype for UINT8 is INT8, skip all other combinations
966 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100967 if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100968 # The only input dtype for UINT8 is INT8, skip all other combinations
969 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100970 if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
971 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100972
Kevin Cheng550ccc52021-03-03 11:21:43 -0800973 for scale32 in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100974 if error_name == ErrorIf.ScaleTrue and scale32 == False:
975 continue
976 elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
977 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800978 for double_round in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100979 if error_name == ErrorIf.ScaleNotTrue and double_round == False:
980 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800981 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700982
Matthew Haddonc2025212021-10-08 21:21:05 +0100983 if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
Eric Kunzee5e26762020-10-13 16:11:07 -0700984 # Illegal condition. Must be scale32=False
985 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100986 if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100987 # Illegal condition. ERROR_IF(!scale32 && double_round)
988 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700989
Kevin Cheng550ccc52021-03-03 11:21:43 -0800990 arg_list.append(
991 (
992 "out{}_sc{}_dr{}_pc{}".format(
993 DTypeNames[dtype],
994 int(scale32),
995 int(double_round),
996 int(per_channel),
997 ),
998 [dtype, scale32, double_round, per_channel],
999 )
1000 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001001
1002 return arg_list
1003
Kevin Chengaee1fac2020-11-11 13:54:06 -08001004 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001005 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -08001006 arg_list = []
1007
1008 if dtype is DType.INT32:
1009 for p in range(testGen.args.num_rand_permutations):
1010
1011 shift = testGen.randInt(0, 32)
1012
Kevin Cheng550ccc52021-03-03 11:21:43 -08001013 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001014 else:
Matthew Haddon43e37192021-07-09 14:13:02 +01001015 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001016
1017 return arg_list
1018
1019 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001020 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -08001021 arg_list = []
1022
Kevin Cheng550ccc52021-03-03 11:21:43 -08001023 arg_list.append(("roundTrue", [True]))
1024 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001025
1026 return arg_list
1027
Eric Kunzee5e26762020-10-13 16:11:07 -07001028 # Helper function for reshape. Gets some factors of a larger number.
1029 @staticmethod
1030 def getFactors(val, start=1):
1031 factors = []
1032
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001033 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -07001034 if (val % i) == 0:
1035 factors.append(i)
1036
1037 return factors
1038
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random RESHAPE target shapes for ``shapeList[0]``.

        Each entry is ``("perm<p>_rank<r>", [newShape])``. ``newShape`` has
        ``r`` dimensions whose product equals the element count of the input
        shape, except that one dimension may be replaced by -1. ``r`` is drawn
        with ``testGen.randInt(1, 7)``.

        A permutation is skipped when the input has fewer small factors (as
        returned by ``getFactors``) than ``r``. It is also skipped when 100
        attempts fail to produce a shape that is not already in ``arg_list``.
        ``error_name`` is accepted for interface compatibility but not used here.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total number of elements the new shape must preserve
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        # getFactors only returns divisors up to sqrt(totalElements).
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough small factors to build a shape of this rank
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors; each pick divides what remains and
                    # the next candidates are the factors of that remainder
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # The last dimension absorbs whatever is left
                newShape.append(remainingElements)

                # Toss in a -1 sometimes (about 1 in 4 if randInt's upper bound
                # is exclusive -- NOTE(review): confirm randInt semantics)
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            # If the retry budget ran out while still duplicating, skip this one
            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
1094
Eric Kunzee5e26762020-10-13 16:11:07 -07001095 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001096 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001097 arg_list = []
1098
1099 ifm_shape = shapeList[0]
1100
Matthew Haddone807aae2021-10-11 18:12:58 +01001101
1102 if error_name == ErrorIf.IndexOutsideBounds:
1103 incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
1104 incorrect_small_index = range(-len(ifm_shape), 0)
1105 permutations = [p for p in itertools.permutations(incorrect_large_index)]
1106 permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
1107 elif error_name == ErrorIf.IndexUsedTwice:
1108 # Create list with a duplicated index
1109 perm_range = list(range(len(ifm_shape)))
1110 index_choice = testGen.rng.choice(range(len(perm_range)))
1111 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1112 permutations = [p for p in itertools.permutations(perm_range)]
1113
1114
1115 else:
1116 # Get all permutations
1117 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -07001118
Jeremy Johnsona6185572021-06-21 15:55:35 +01001119 # Limit to possible permutations from shape dimension or argument setting
1120 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001121
Jeremy Johnsona6185572021-06-21 15:55:35 +01001122 # Get random permutation generator that uses all permutations
1123 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001124
Jeremy Johnsona6185572021-06-21 15:55:35 +01001125 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -07001126 arg_list = [
1127 ("perm{}".format(p), [random_permutations[p].tolist()])
1128 for p in range(limit)
1129 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07001130 return arg_list
1131
1132 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001133 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001134 arg_list = []
1135
1136 ifm_shape = shapeList[0]
1137 rank = len(ifm_shape)
1138
1139 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +01001140 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -07001141 size = []
1142
Kevin Cheng550ccc52021-03-03 11:21:43 -08001143 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -07001144
1145 for i in range(rank):
1146 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +01001147 start.append(testGen.randInt(0, ifm_shape[i]))
1148 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001149
1150 # Invalid slice size?
1151 if size[i] == 0:
1152 valid = False
1153 else:
Matthew Haddone807aae2021-10-11 18:12:58 +01001154 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -07001155 size.append(1)
1156
1157 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +01001158 # If ERROR_IF test required then incorrect start, size will be returned
1159 start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
1160 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001161 return arg_list
1162
1163 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001164 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001165 arg_list = []
1166
1167 ifm_shape = shapeList[0]
1168 rank = len(ifm_shape)
1169
1170 for p in range(testGen.args.num_rand_permutations):
1171
1172 # Pick a few random, but small multiple values
1173 # because otherwise this has a tendency to generate
1174 # enormous tensors
1175 multiples = []
1176 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001177 if ifm_shape[i] > 1000:
1178 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
1179 multiples.append(1)
1180 elif max(ifm_shape) > 1000:
1181 multiples.append(2)
1182 else:
1183 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001184 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001185
1186 return arg_list
1187
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate RESIZE arguments for each legal (mode, output dtype) pair.

        Each entry is ``(name, [mode, stride, offset, shift, stride_fp,
        offset_fp, output_dims, dtype, outputDType])``. Output dims are random.
        The floating-point stride and offset are derived from the input and
        output sizes so that the output centre maps onto the input centre.

        FLOAT outputs use stride_fp/offset_fp with shift 0. Integer outputs
        search shift values in 1..11 for a fixed-point stride/offset that stays
        within +/-(16 << shift) and yields in-bounds input indices; the
        permutation is skipped if no such shift is found.

        When ``error_name`` is set, ``TosaErrorIfArgGen.eiResizeErrorIf`` may
        replace these values (and the output dtype) with invalid ones.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    # NOTE(review): relies on randInt(1) returning >= 1 -- confirm
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Grow each output dim until the downscale ratio is below 16
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    # Offsets chosen so the output centre maps to the input centre
                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # Floating-point resize: integer stride/offset/shift unused
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, offset_x] if False else [fp_offset_y, fp_offset_x]
1385
Kevin Chengfe392ce2021-10-18 21:51:55 +00001386 @staticmethod
1387 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1388 arg_list = []
1389
1390 if dtype == DType.INT8:
1391 table = np.int32(
1392 testGen.rng.integers(low=-128, high=128, size=[256])
1393 ).tolist()
1394 else: # INT16
1395 table = np.int32(
1396 testGen.rng.integers(low=-32768, high=32768, size=[513])
1397 ).tolist()
1398
1399 arg_list.append(
1400 (
1401 "",
1402 [table],
1403 )
1404 )
1405 return arg_list
1406
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Return one argument entry per COND_IF branch: False ("cond0") then True ("cond1").

    Only the condition values are produced here; the build function turns
    them into tensors and creates the then/else blocks.
    """
    return [("cond{}".format(int(flag)), [flag]) for flag in (False, True)]
1417
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Return WHILE_LOOP argument entries for 0, 1 and more than 1 (4) iterations."""
    return [("iter{}".format(count), [count]) for count in (0, 1, 4)]
1426
class TosaErrorIfArgGen:
    """Argument mutators used to build ERROR_IF (negative) test cases.

    Each helper takes otherwise-valid operator arguments and returns altered
    values intended to hit the requested ``ErrorIf`` condition. Sequences
    passed in by the caller are never modified in place.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        shift,
        stride,
        stride_fp,
        offset,
        offset_fp,
    ):
        """Corrupt RESIZE parameters so that ``error_name`` is triggered.

        Returns ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``.
        ``stride``/``stride_fp`` are copied before an element is overwritten,
        so the caller's lists are left untouched.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                stride_fp = list(stride_fp)
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                # integers(-3, 1) draws from [-3, 0], so the right-shift form
                # is the one used; the left-shift form is kept for safety.
                shift = testGen.rng.integers(-3, 1)
                limit = ((16 >> -shift) if shift <= 0 else (16 << shift)) - 1
                # Stay just below the maximum to avoid other ERROR_IF checks
                stride = [limit, limit]
                offset = [limit, limit]
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                stride = list(stride)
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Every dtype except the one legal output for this mode/input pair
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Corrupt pooling parameters so that ``error_name`` is triggered.

        Returns ``(stride, pad, kernel)`` with one element replaced. Returns
        ``(None, None, None)`` when ``error_name`` is not handled here. This
        includes StrideSmallerOne requested while some padding is not below
        the kernel size.
        """
        if (error_name == ErrorIf.StrideSmallerOne
                # padding must not exceed the kernel size
                and pad[0] < kernel[0] and pad[1] < kernel[0]
                and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]),
                           testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]),
                           testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            # pad[0..1] are checked against kernel[0], pad[2..3] against kernel[1]
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                        testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                        testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                        testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if ``output_dtype`` is not a legal RESCALE output for ``input_dtype``.

        Input types without a rule below are never reported.
        """
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Return new input/output name lists broken for list-size ERROR_IF checks.

        WrongInputList randomly appends a dummy input or drops the last input.
        WrongOutputList randomly appends a dummy output or empties the list.
        Any other ``error_name`` returns the lists unchanged. The caller's
        lists are never mutated.
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list = [*input_list, 'eiDummyInput']
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list = [*output_list, 'eiDummyOutput']
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to max_dim and max_items.

        First every dimension is clamped to ``max_dim``. Then all dimensions
        are reduced by one, never below 1, until ``product(shape)`` is at most
        ``max_items`` or every dimension is 1. An empty shape skips the clamp
        instead of failing in ``max()``.
        """
        if shape and max(shape) > max_dim:
            new_shape = [min(d, max_dim) for d in shape]
        else:
            new_shape = shape
        # The any() guard prevents an endless loop when max_items < 1
        while product(new_shape) > max_items and any(d > 1 for d in new_shape):
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return ``(start, size)`` altered so SLICE hits ``error_name``.

        Unhandled errors return the inputs unchanged. New lists are built
        rather than appending to the caller's ``start``/``size``.
        """
        if error_name == ErrorIf.StartSmallerZero:
            newStart = [testGen.rng.choice([-3, -2, -1]) for _ in input_shape]
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = [testGen.rng.choice([-3, -2, -1, 0]) for _ in input_shape]
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            # Start at the last element, then ask for at least two more
            newStart = [dim - 1 for dim in input_shape]
            newSize = [testGen.rng.choice([2, 3, 4]) for _ in input_shape]
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = list(start) + [1]
                newSize = list(size) + [1]
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of illegal CAST output dtypes for ``input_dtype``.

        Raises:
            ValueError: if ``input_dtype`` has no rule here. Previously
                ``assert True`` never fired, and an UnboundLocalError
                followed.
        """
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            return [DType.BOOL, DType.INT48, DType.FLOAT]
        if input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            return [DType.INT48]
        raise ValueError(f"input_dtype ({input_dtype}) not supported")
1596
1597
Matthew Haddone86fd342021-09-07 16:12:21 +01001598class TosaErrorValidator:
1599
Matthew Haddon848efb42021-09-09 12:30:53 +01001600 @staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
    """Run every ERROR_IF validator and reconcile it with the targeted error.

    Each validator is called with ``check=True`` and ``kwargs``. A validator
    that fires, and whose name equals ``error_name``, sets the serializer's
    expected return code to 2 with that validator's reason. A validator that
    fires for a different error, or fails to fire for the targeted one, is
    reported on stdout together with the non-``op`` kwargs. Keys ending in
    ``dtype`` are printed by name.
    """
    for validate in validator_fcns:
        outcome = validate(True, **kwargs)
        validator_name = outcome['error_name']
        raised = outcome['error_result']
        reason = outcome['error_reason']

        # Correct iff the validator fires exactly when it is the targeted error
        as_expected = raised == (error_name == validator_name)

        if as_expected:
            if raised:
                serializer.setExpectedReturnCode(2, reason)
            continue

        op_name = valueToName(Op, kwargs['op']['op'])
        if raised:
            print(f"Unexpected ERROR_IF: Op: {op_name}"
                  f" Expected: {error_name}, Got: {validator_name}")
        else:
            print(f"Missed ERROR_IF: Op: {op_name}"
                  f" Expected: {error_name}")

        for key, value in sorted(kwargs.items()):
            if key == 'op':
                continue
            if key.endswith('dtype'):
                value = valueToName(DType, value)
            print(f' {key} = {value}')
Matthew Haddon848efb42021-09-09 12:30:53 +01001627
1628 @staticmethod
def evWrongInputType(check=False, **kwargs):
    """Report whether ``input_dtype`` is outside the op's supported input types.

    ``kwargs['op']['types']`` lists the supported types. An entry may be a
    list, in which case its first element is the input type. The returned
    ``param_reqs['dtype']`` holds the usable dtypes that are NOT supported,
    with INT48 dropped for CLAMP.
    """
    op = kwargs['op']

    supported = set()
    for type_entry in op['types']:
        supported.add(type_entry[0] if isinstance(type_entry, list) else type_entry)

    candidate_dtypes = list(usableDTypes(excludes=supported))
    if op['op'] == Op.CLAMP:
        candidate_dtypes.remove(DType.INT48)

    error_result = False
    if check:
        error_result = kwargs['input_dtype'] not in supported

    return dict(
        error_name=ErrorIf.WrongInputType,
        error_result=error_result,
        error_reason="Input data type not supported for this operator",
        param_reqs={"rank": None, "dtype": candidate_dtypes, "shape": None},
    )
1653
1654 @staticmethod
def evWrongOutputType(check=False, **kwargs):
    """Report whether ``output_dtype`` is illegal for the op and ``input_dtype``.

    Only evaluated when ``check`` is true. Reads ``op``, ``input_dtype`` and
    ``output_dtype`` from ``kwargs``, plus ``mode`` for RESIZE. Ops without a
    dedicated rule require output type == input type. For table-driven ops,
    an input type with no entry never produces this error.
    """
    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        op_type = kwargs['op']['op']

        def differs_from(expected):
            # expected: {input dtype: the single legal output dtype}
            return input_dtype in expected and output_dtype != expected[input_dtype]

        def outside_of(allowed):
            # allowed: {input dtype: collection of legal output dtypes}
            return input_dtype in allowed and output_dtype not in allowed[input_dtype]

        # INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT
        accumulator_out = {
            DType.INT8: DType.INT32,
            DType.INT16: DType.INT48,
            DType.FLOAT: DType.FLOAT,
        }

        if op_type == Op.RESIZE:
            mode = kwargs['mode']
            if mode == ResizeMode.NEAREST:
                resize_out = {DType.INT8: DType.INT8, DType.INT16: DType.INT16}
            elif mode == ResizeMode.BILINEAR:
                resize_out = {DType.INT8: DType.INT32, DType.INT16: DType.INT48}
            else:
                resize_out = {}
            # FLOAT must produce FLOAT whatever the mode
            resize_out[DType.FLOAT] = DType.FLOAT
            mismatch = differs_from(resize_out)
        elif op_type == Op.RESCALE:
            int_outs = (DType.INT8, DType.INT16, DType.INT32)
            mismatch = outside_of({
                DType.INT8: (DType.UINT8,) + int_outs,
                DType.INT16: int_outs,
                DType.INT32: int_outs,
                DType.INT48: int_outs,
                DType.UINT8: (DType.INT8,),
            })
        elif op_type in (Op.FULLY_CONNECTED, Op.MATMUL):
            mismatch = differs_from(accumulator_out)
        elif op_type == Op.ARGMAX:
            mismatch = (input_dtype in (DType.INT8, DType.INT16, DType.FLOAT)
                        and output_dtype != DType.INT32)
        elif op_type == Op.MUL:
            # FLOAT stays FLOAT; every other input type produces INT32
            required = DType.FLOAT if input_dtype == DType.FLOAT else DType.INT32
            mismatch = output_dtype != required
        elif op_type == Op.TABLE:
            mismatch = differs_from({DType.INT8: DType.INT8, DType.INT16: DType.INT32})
        elif op_type in (Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER):
            mismatch = output_dtype != DType.BOOL
        elif op_type == Op.CAST:
            mismatch = outside_of({
                DType.BOOL: (DType.INT8, DType.INT16, DType.INT32),
                DType.INT8: (DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT),
                DType.INT16: (DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT),
                DType.INT32: (DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT),
                DType.FLOAT: (DType.INT8, DType.INT16, DType.INT32),
            })
        elif op_type in (Op.CONV2D, Op.CONV3D, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D):
            # invalid input types are ignored, to avoid reporting multiple errors
            mismatch = differs_from(accumulator_out)
        else:
            mismatch = output_dtype != input_dtype

        if mismatch:
            error_result = True

    return dict(
        error_name=ErrorIf.WrongOutputType,
        error_result=error_result,
        error_reason="Output data type not supported for this configuration of operator",
        param_reqs={"rank": None, "dtype": None, "shape": None},
    )
1746
1747 @staticmethod
def evWrongRank(check=False, **kwargs):
    """Report whether the input rank is illegal for the op.

    ``param_reqs['rank']`` lists ranks to use when generating WrongRank
    tests. RESIZE, TRANSPOSE and CONV3D use fixed lists. Other ops use the
    ranks 1-5 that lie outside ``op['rank']`` but above its minimum, which
    avoids index errors.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    rmin, rmax = op['rank']
    rank_range = range(rmin, rmax + 1)
    op_type = op['op']

    if op_type == Op.RESIZE:
        incorrect_ranks = [3, 5]
    elif op_type == Op.TRANSPOSE:
        incorrect_ranks = [7, 8]
    elif op_type == Op.CONV3D:
        incorrect_ranks = [6, 7]
    else:
        out_of_range = list(set((1, 2, 3, 4, 5)) - set(rank_range))
        incorrect_ranks = [rank for rank in out_of_range if rank > rmin]

    error_result = False
    if check:
        rank = len(kwargs['input_shape'])
        if (
            (op_type in (Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D) and rank != 4)
            or (op_type == Op.FULLY_CONNECTED and rank != 2)
            or (op_type == Op.MATMUL and rank != 3)
            or rank not in rank_range
        ):
            error_result = True

    return dict(
        error_name=ErrorIf.WrongRank,
        error_result=error_result,
        error_reason="Rank not supported for this operator",
        param_reqs={"rank": incorrect_ranks, "dtype": None, "shape": None},
    )
1792
1793 @staticmethod
def evWrongInputList(check=False, **kwargs):
    """Report whether the op's input list length differs from its operand count.

    SCATTER and GATHER expect one extra input, because their build functions
    add an indices tensor.
    """
    error_result = False
    if check:
        op_type = kwargs['op']['op']
        actual_inputs = len(kwargs['input_list'])
        expected_inputs = kwargs['num_operands']
        if op_type in (Op.SCATTER, Op.GATHER):
            expected_inputs += 1
        if actual_inputs != expected_inputs:
            error_result = True

    return dict(
        error_name=ErrorIf.WrongInputList,
        error_result=error_result,
        error_reason="Op input list does not match expected input",
        param_reqs={"rank": None, "dtype": None, "shape": None},
    )
1817
1818 @staticmethod
def evWrongOutputList(check=False, **kwargs):
    """Report whether the op's output list does not hold exactly one output.

    This assumes every operator has a single output. It would misreport for
    an operator with several outputs.
    """
    error_result = False
    if check and len(kwargs['output_list']) != 1:
        error_result = True

    return dict(
        error_name=ErrorIf.WrongOutputList,
        error_result=error_result,
        error_reason="Op output list does not match expected output",
        param_reqs={"rank": None, "dtype": None, "shape": None},
    )
Matthew Haddone86fd342021-09-07 16:12:21 +01001838
1839 @staticmethod
def evMaxDimExceeded(check=False, **kwargs):
    """Report whether input H/W or output (OH, OW) exceeds 16384.

    ``input_shape`` indices 1 and 2 are checked. ``output_shape`` holds only
    (OH, OW).
    """
    max_size = 16384
    error_result = False
    if check:
        in_shape = kwargs['input_shape']
        out_hw = kwargs['output_shape']  # just (OH, OW)
        if (in_shape[1] > max_size or in_shape[2] > max_size
                or out_hw[0] > max_size or out_hw[1] > max_size):
            error_result = True

    return dict(
        error_name=ErrorIf.MaxDimExceeded,
        error_result=error_result,
        error_reason="At least one maximum dimension is larger than 16384",
        param_reqs={
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        },
    )
1866
1867 @staticmethod
def evBatchMismatch(check=False, **kwargs):
    """Report whether input and output batch (index 0) differ.

    The check only applies when the input rank lies in ``op['rank']``.
    """
    assert 'op' in kwargs
    rmin, rmax = kwargs['op']['rank']
    valid_ranks = range(rmin, rmax + 1)

    error_result = False
    if check:
        input_shape = kwargs['input_shape']
        output_shape = kwargs['result_tensor'].shape  # (N, OH, OW, C)
        if len(input_shape) in valid_ranks and input_shape[0] != output_shape[0]:
            error_result = True

    return dict(
        error_name=ErrorIf.BatchMismatch,
        error_result=error_result,
        error_reason="Input batch size not equal to output batch size",
        param_reqs={"rank": [4, 4], "dtype": None, "shape": None},
    )
1893
1894 @staticmethod
def evChannelMismatch(check=False, **kwargs):
    """Report whether input and output channels (index 3) differ.

    The check only applies when the input rank lies in ``op['rank']``.
    """
    assert 'op' in kwargs
    rmin, rmax = kwargs['op']['rank']
    valid_ranks = range(rmin, rmax + 1)

    error_result = False
    if check:
        input_shape = kwargs['input_shape']
        output_shape = kwargs['result_tensor'].shape  # (N, OH, OW, C)
        if len(input_shape) in valid_ranks and input_shape[3] != output_shape[3]:
            error_result = True

    return dict(
        error_name=ErrorIf.ChannelMismatch,
        error_result=error_result,
        error_reason="Input channel size not equal to output channel size",
        param_reqs={"rank": [4, 4], "dtype": None, "shape": None},
    )
1919
1920 @staticmethod
def evStrideSmallerEqualZero(check=False, **kwargs):
    """Report whether any stride component is <= 0.

    FLOAT inputs are checked against ``stride_fp``. All other inputs are
    checked against the integer ``stride``, including wrong-output-type
    tests whose output is FLOAT.
    """
    error_result = False
    if check:
        stride_key = 'stride_fp' if kwargs['input_dtype'] == DType.FLOAT else 'stride'
        if min(kwargs[stride_key]) <= 0:
            error_result = True

    return dict(
        error_name=ErrorIf.StrideSmallerEqualZero,
        error_result=error_result,
        error_reason="Stride value smaller than or equal zero",
        param_reqs={"rank": None, "dtype": None, "shape": None},
    )
1949
1950 @staticmethod
def evStrideLargerEqualMax(check=False, **kwargs):
    """Report whether an INT8/INT16 stride component reaches 16 scaled by ``shift``.

    The bound is ``16 << shift`` for non-negative shifts and ``16 >> -shift``
    otherwise.
    """
    error_result = False
    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride']
        if input_dtype in [DType.INT8, DType.INT16]:
            bound = (16 << shift) if shift >= 0 else (16 >> -shift)
            if stride[0] >= bound or stride[1] >= bound:
                error_result = True

    return dict(
        error_name=ErrorIf.StrideLargerEqualMax,
        error_result=error_result,
        error_reason="Stride value larger than or equal to maximum value",
        param_reqs={"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    )
1974
1975
1976 @staticmethod
def evStrideLargerDimension(check=False, **kwargs):
    """Report whether a FLOAT stride exceeds the input height or width.

    ``stride_fp[0]`` is compared with ``input_shape[1]`` and ``stride_fp[1]``
    with ``input_shape[2]``. These are the same indices eiResizeErrorIf uses
    for height and width.
    """
    error_name = ErrorIf.StrideLargerDimension
    param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    error_result = False
    error_reason = "Stride value larger than or equal to H/W dimension"

    if check:
        shape = kwargs['input_shape']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride_fp']

        # Parenthesised so the FLOAT guard covers BOTH comparisons. Previously
        # `A and B or C` let non-FLOAT inputs trip the width check alone.
        if input_dtype == DType.FLOAT and (stride[0] > shape[1] or stride[1] > shape[2]):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1998
1999
2000 @staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
    """Report whether an offset component is at or below -16 scaled by ``shift``.

    ``offset_fp`` is checked when the output type is FLOAT; otherwise
    ``offset`` is checked. The bound is ``-16 << shift`` for non-negative
    shifts and ``-16 >> -shift`` otherwise.
    """
    error_result = False
    if check:
        shift = kwargs['shift']
        use_fp = kwargs['output_dtype'] == DType.FLOAT
        offset = kwargs['offset_fp'] if use_fp else kwargs['offset']
        bound = (-16 << shift) if shift >= 0 else (-16 >> -shift)
        if offset[0] <= bound or offset[1] <= bound:
            error_result = True

    return dict(
        error_name=ErrorIf.OffsetSmallerEqualMin,
        error_result=error_result,
        error_reason="Offset value smaller than or equal to minimum value",
        param_reqs={"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    )
2027
2028 @staticmethod
def evOffsetLargerEqualMax(check=False, **kwargs):
    """Report whether an offset component reaches 16 scaled by ``shift``.

    ``offset_fp`` is checked when the output type is FLOAT; otherwise
    ``offset`` is checked. The bound is ``16 << shift`` for non-negative
    shifts and ``16 >> -shift`` otherwise.
    """
    error_name = ErrorIf.OffsetLargerEqualMax
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Offset value larger than or equal to maximum value"

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        if output_dtype == DType.FLOAT:
            offset = kwargs['offset_fp']
        else:
            offset = kwargs['offset']

        # A single bound covers both shift signs. This replaces a duplicated
        # `shift >= 0` check that was evaluated twice.
        bound = (16 << shift) if shift >= 0 else (16 >> -shift)
        if offset[0] >= bound or offset[1] >= bound:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2059
2060 @staticmethod
def evShiftNotZero(check=False, **kwargs):
    """Report a non-zero ``shift`` when both input and output types are FLOAT."""
    error_result = False
    if check:
        shift = kwargs['shift']
        all_float = kwargs['input_dtype'] == DType.FLOAT and kwargs['output_dtype'] == DType.FLOAT
        if all_float and shift != 0:
            error_result = True

    return dict(
        error_name=ErrorIf.ShiftNotZero,
        error_result=error_result,
        error_reason="Shift value must be zero for float input",
        param_reqs={"rank": None, "dtype": [DType.FLOAT], "shape": None},
    )
2081
2082
2083 @staticmethod
def evShiftSmallerOne(check=False, **kwargs):
    """Report ``shift < 1`` when neither the input nor the output type is FLOAT."""
    error_result = False
    if check:
        shift = kwargs['shift']
        involves_float = DType.FLOAT in (kwargs['input_dtype'], kwargs['output_dtype'])
        if shift < 1 and not involves_float:
            error_result = True

    return dict(
        error_name=ErrorIf.ShiftSmallerOne,
        error_result=error_result,
        error_reason="Shift value smaller than one",
        param_reqs={"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    )
2104
2105 @staticmethod
def evShiftLargerEleven(check=False, **kwargs):
    """Report a ``shift`` greater than 11, regardless of dtype."""
    error_result = False
    if check and kwargs['shift'] > 11:
        error_result = True

    return dict(
        error_name=ErrorIf.ShiftLargerEleven,
        error_result=error_result,
        error_reason="Shift value larger than eleven",
        param_reqs={"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None},
    )
2124
2125
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002126 @staticmethod
def evRankMismatch(check=False, **kwargs):
    """Report whether any input's rank differs from the result tensor's rank.

    ``input3`` is only present for SELECT. Without it, only ``input1`` and
    ``input2`` are compared.
    """
    error_result = False
    if check:
        input_shapes = [kwargs['input1'].shape, kwargs['input2'].shape]
        if 'input3' in kwargs:  # SELECT op
            input_shapes.append(kwargs['input3'].shape)
        output_rank = len(kwargs['result_tensor'].shape)
        if any(len(shape) != output_rank for shape in input_shapes):
            error_result = True

    return dict(
        error_name=ErrorIf.RankMismatch,
        error_result=error_result,
        error_reason="Input Rank does not match output rank",
        param_reqs={"rank": None, "dtype": None, "shape": None},
    )
2153
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002154 @staticmethod
def evDimensionMismatch(check=False, **kwargs):
    """Report whether an input dimension neither equals the output dimension nor is 1.

    Dimensions are compared position by position over the ranks shared by all
    inputs and the output. ``input3`` is only present for SELECT; otherwise
    ``input2`` stands in for it.
    """
    error_name = ErrorIf.DimensionMismatch
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Input Dimensions do not match output"

    if check:
        input1_shape = kwargs['input1'].shape
        input2_shape = kwargs['input2'].shape
        # In case of SELECT op
        input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
        output_shape = kwargs['result_tensor'].shape
        input_shapes = (input1_shape, input2_shape, input3_shape)

        # Bound by the output rank as well. A rank mismatch (reported by
        # RankMismatch) previously raised IndexError here.
        common_rank = min(len(input1_shape), len(input2_shape),
                          len(input3_shape), len(output_shape))
        for i in range(common_rank):
            out_dim = output_shape[i]
            # A dimension of 1 broadcasts; anything else must match the output
            if any(shape[i] != 1 and shape[i] != out_dim for shape in input_shapes):
                error_result = True
                break

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2182
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
    """Report whether a non-quantizable input carries a non-zero zero point.

    Keyword args: ``op`` (op dict with ``'op'`` and ``'types'``) is always
    read. When ``check`` is true, also reads ``input_dtype`` and ``qinfo``
    (either a plain tuple whose [0] is the input zero point, or an object
    whose ``.ints[0][1]`` is). For MATMUL, ``input2_dtype`` is read too and
    is paired with ``qinfo.ints[1][1]``.

    Returns the ErrorIf info dict; ``param_reqs['dtype']`` lists the op
    types whose (first) entry is not INT8/UINT8.
    """
    op = kwargs['op']
    quantizable = (DType.INT8, DType.UINT8)

    # Multi-type ops list their types as a list; the leading entry is the input type
    candidate_dtypes = [
        entry for entry in op['types']
        if (entry[0] if isinstance(entry, list) else entry) not in quantizable
    ]

    flagged = False
    if check:
        in_dtype = kwargs['input_dtype']
        raw_qinfo = kwargs['qinfo']
        if isinstance(raw_qinfo, tuple):
            in_zp = raw_qinfo[0]
        else:
            # qinfo.ints entries are (name, value) pairs; [0] holds the input zero point
            in_zp = raw_qinfo.ints[0][1]

        if op['op'] == Op.MATMUL:
            # MATMUL has two inputs, each paired with its own zero point
            zp_table = kwargs['qinfo'].ints
            operands = (
                (kwargs['input_dtype'], zp_table[0][1]),
                (kwargs['input2_dtype'], zp_table[1][1]),
            )
            flagged = any(dt not in quantizable and zp != 0 for dt, zp in operands)
        else:
            flagged = in_dtype not in quantizable and in_zp != 0

    return {
        "error_name": ErrorIf.InputZeroPointNotZero,
        "error_result": flagged,
        "error_reason": "Input DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": candidate_dtypes, "shape": None}
    }
2227
2228
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
    """Report whether a non-INT8 weight tensor has a non-zero zero point.

    Keyword args: ``op`` (op dict with ``'types'``) is always read. When
    ``check`` is true, ``weight_dtype`` and ``qinfo`` are read;
    ``qinfo.ints[1][1]`` is taken as the weight zero point.

    Returns the ErrorIf info dict; ``param_reqs['dtype']`` drops the
    list-form type combinations whose weight entry (index 1) is INT8.
    """
    op = kwargs['op']
    candidate_dtypes = [
        entry for entry in op['types']
        if not (isinstance(entry, list) and entry[1] == DType.INT8)
    ]

    flagged = False
    if check:
        w_dtype = kwargs['weight_dtype']
        # qinfo.ints[0][1] = input zero point, qinfo.ints[1][1] = weight zero point
        w_zp = kwargs['qinfo'].ints[1][1]
        flagged = bool(w_dtype != DType.INT8 and w_zp != 0)

    return {
        "error_name": ErrorIf.WeightZeroPointNotZero,
        "error_result": flagged,
        "error_reason": "Weight DType not INT8 and zero point not 0",
        "param_reqs": {
            "rank": None,
            "dtype": candidate_dtypes,
            "shape": None
        }
    }
2261
2262
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
    """Report whether a non-quantized output has a non-zero output zero point.

    Keyword args: ``op`` (op dict with ``'op'`` and ``'types'``) is always
    read. When ``check`` is true, also reads ``input_dtype``,
    ``output_dtype`` and ``qinfo``. ``qinfo`` is either a tuple whose [1] is
    the output zero point, or an object whose ``.ints[1][1]`` is.

    For AVG_POOL2D the input dtype is tested against INT8. For every other
    op the output dtype is tested against INT8/UINT8.

    Returns the ErrorIf info dict; ``param_reqs['dtype']`` is the op types
    with the first INT8 and first UINT8 entries removed.
    """
    op = kwargs['op']
    candidate_dtypes = op['types'].copy()
    for quantized in (DType.INT8, DType.UINT8):
        # list.remove drops only the first matching entry
        if quantized in candidate_dtypes:
            candidate_dtypes.remove(quantized)

    flagged = False
    if check:
        in_dtype = kwargs['input_dtype']
        out_dtype = kwargs['output_dtype']
        raw_qinfo = kwargs['qinfo']
        if isinstance(raw_qinfo, tuple):
            out_zp = raw_qinfo[1]
        else:
            # qinfo.ints[0][1] = input zero point, qinfo.ints[1][1] = output zero point
            out_zp = raw_qinfo.ints[1][1]

        if op['op'] == Op.AVG_POOL2D:
            flagged = bool(in_dtype != DType.INT8 and out_zp != 0)
        else:
            flagged = bool(out_dtype not in (DType.INT8, DType.UINT8) and out_zp != 0)

    return {
        "error_name": ErrorIf.OutputZeroPointNotZero,
        "error_result": flagged,
        "error_reason": "Output DType not INT8 and zero point not 0",
        "param_reqs": {
            "rank": None,
            "dtype": candidate_dtypes,
            "shape": None
        }
    }
2304
@staticmethod
def evAxisSmallerZero(check=False, **kwargs):
    """Report whether ``kwargs['axis']`` is negative (only evaluated when ``check``).

    Returns the ErrorIf info dict.
    """
    negative = bool(kwargs['axis'] < 0) if check else False

    return {
        "error_name": ErrorIf.AxisSmallerZero,
        "error_result": negative,
        "error_reason": "Axis smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2324
2325
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
    """Report whether ``axis`` is beyond the last dimension of ``input_shape``.

    Keyword args (read only when ``check`` is true): ``axis`` and
    ``input_shape``. Valid axes are 0 .. rank-1, so ``axis == rank`` is
    already out of range and is flagged.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.AxisLargerRank
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Axis larger than rank"

    if check:
        axis = kwargs['axis']
        shape = kwargs['input_shape']
        # axis indexes shape, so the first invalid value is len(shape)
        if axis >= len(shape):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2346
2347
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
    """Report whether ``output_shape[axis]`` is not 1.

    Keyword args (read only when ``check`` is true): ``axis`` and
    ``output_shape``. An axis outside ``[0, len(output_shape))`` is never
    flagged here; the separate axis-range checks handle that case.

    Returns the ErrorIf info dict.
    """
    flagged = False
    if check:
        axis = kwargs['axis']
        out_shape = kwargs['output_shape']
        axis_in_range = 0 <= axis < len(out_shape)
        flagged = bool(axis_in_range and out_shape[axis] != 1)

    return {
        "error_name": ErrorIf.ShapeOfAxisNotOne,
        "error_result": flagged,
        "error_reason": "shape[axis] is not equal to 1",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2368
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002369
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
    """Report whether any padding value is negative.

    Keyword args (read only when ``check`` is true): ``op`` and ``pad``.
    For PAD, ``pad`` is iterated as a collection of per-dimension padding
    sequences. For other ops, ``pad`` is a flat sequence of pad values.

    Returns the ErrorIf info dict.
    """
    negative = False
    if check:
        op = kwargs['op']
        pad = kwargs['pad']
        if op['op'] == Op.PAD:
            # Build the full list so every entry is evaluated, as before
            negative = any([min(entry) < 0 for entry in pad])
        else:
            negative = bool(min(pad) < 0)

    return {
        "error_name": ErrorIf.PadSmallerZero,
        "error_result": negative,
        "error_reason": "At least one pad is smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2395
2396
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
    """Report whether a pooling pad is greater than or equal to the kernel size.

    Keyword args (read only when ``check`` is true): ``pad`` as
    [top, bottom, left, right] and ``kernel`` as [kernel_y, kernel_x].
    Top/bottom pads are compared with kernel_y and left/right pads with
    kernel_x.

    The comparison is skipped only when a pad is negative or a kernel
    dimension is below one; those cases belong to evPadSmallerZero and
    evKernelSmallerOne. This is the same validity guard used by
    evPoolingOutputShapeMismatch. A zero pad on one edge, or a 1-wide
    kernel, therefore no longer hides an oversized pad elsewhere.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.PadLargerEqualKernel
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one pad is larger than kernel dimension"

    if check:
        pad = kwargs['pad']
        kernel = kwargs['kernel']
        if min(pad) >= 0 and min(kernel) >= 1:
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
            kernel_y, kernel_x = kernel[0], kernel[1]
            if (pad_top >= kernel_y or pad_bottom >= kernel_y
                    or pad_left >= kernel_x or pad_right >= kernel_x):
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2418
@staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
    """Report whether a pooling output H/W differs from the expected size.

    Keyword args (read only when ``check`` is true):
    - ``pad``: [top, bottom, left, right]
    - ``kernel``: [kernel_y, kernel_x]
    - ``stride``: [stride_y, stride_x]
    - ``input_shape`` and ``output_shape``: H and W are read from index 1
      and index 2.

    The expected size is ``(in + pad_a + pad_b + stride - kernel) // stride``
    per axis. It is only computed, and compared, when kernel >= 1,
    stride >= 1, pad >= 0 and every pad is smaller than its kernel
    dimension; otherwise another ERROR_IF applies.

    Returns the ErrorIf info dict.
    """
    mismatch = False
    if check:
        pad = kwargs['pad']
        top, bottom, left, right = pad[0], pad[1], pad[2], pad[3]

        kernel = kwargs['kernel']
        k_y, k_x = kernel[0], kernel[1]

        in_shape = kwargs['input_shape']
        in_h, in_w = in_shape[1], in_shape[2]

        out_shape = kwargs['output_shape']
        out_h, out_w = out_shape[1], out_shape[2]

        stride = kwargs['stride']
        s_y, s_x = stride[0], stride[1]

        pads_fit = top < k_y and bottom < k_y and left < k_x and right < k_x
        if min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0 and pads_fit:
            expected_h = (in_h + top + bottom + s_y - k_y) // s_y
            expected_w = (in_w + left + right + s_x - k_x) // s_x
            mismatch = bool(out_h != expected_h or out_w != expected_w)

    return {
        "error_name": ErrorIf.PoolingOutputShapeMismatch,
        "error_result": mismatch,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2461
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
    """Report whether the ARGMAX output shape is not the input shape with ``axis`` removed.

    Keyword args (read only when ``check`` is true): ``output_shape``,
    ``input_shape`` and ``axis``.

    Dimensions are compared only when ``0 <= axis < rank`` and the output
    rank is ``rank - 1``. Those preconditions are reported by the
    rank/axis checks. The axis guard also prevents an out-of-range axis
    from indexing past the end of ``output_shape``.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.ArgmaxOutputShapeMismatch
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Mismatch between output shape provided and expected output shape"

    if check:
        output_shape = kwargs['output_shape']
        input_shape = kwargs['input_shape']
        axis = kwargs['axis']

        rank = len(input_shape)
        if 0 <= axis < rank and (rank - 1) == len(output_shape):
            expected_shape = list(input_shape[:axis]) + list(input_shape[axis + 1:])
            error_result = any(
                exp != out for exp, out in zip(expected_shape, output_shape)
            )

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2496
@staticmethod
def evArgmaxOutputRankMismatch(check=False, **kwargs):
    """Report whether the ARGMAX output rank is not the input rank minus one.

    Keyword args (read only when ``check`` is true): ``output_shape``,
    ``input_shape`` and ``axis``. Nothing is flagged unless
    ``0 <= axis < len(input_shape)``.

    Returns the ErrorIf info dict.
    """
    flagged = False
    if check:
        out_shape = kwargs['output_shape']
        in_shape = kwargs['input_shape']
        axis = kwargs['axis']
        axis_ok = axis >= 0 and axis < len(in_shape)
        flagged = bool(axis_ok and len(in_shape) - 1 != len(out_shape))

    return {
        "error_name": ErrorIf.ArgmaxOutputRankMismatch,
        "error_result": flagged,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2520
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002521
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
    """Report whether any kernel dimension is smaller than one.

    Keyword args (read only when ``check`` is true): ``kernel``.
    The error reason text now states "smaller than one", matching the
    ``< 1`` test and evDilationSmallerOne's wording.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.KernelSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one kernel dimension is smaller than one"

    if check:
        kernel = kwargs['kernel']
        if min(kernel) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2541
@staticmethod
def evStrideSmallerOne(check=False, **kwargs):
    """Report whether any stride dimension is smaller than one.

    Keyword args (read only when ``check`` is true): ``stride``.
    The error reason text now states "smaller than one", matching the
    ``< 1`` test and evDilationSmallerOne's wording.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.StrideSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one stride dimension is smaller than one"

    if check:
        stride = kwargs['stride']
        if min(stride) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2561
@staticmethod
def evDilationSmallerOne(check=False, **kwargs):
    """Report whether any entry of ``kwargs['dilation']`` is below one.

    When ``check`` is falsy, ``error_result`` is ``check`` itself.

    Returns the ErrorIf info dict.
    """
    below_one = check
    if check:
        below_one = min(kwargs['dilation']) < 1

    return {
        "error_name": ErrorIf.DilationSmallerOne,
        "error_reason": "At least one dilation is smaller than one",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
        "error_result": below_one
    }
2571
@staticmethod
def evScaleTrue(check=False, **kwargs):
    """Report whether ``scale32`` is set while ``input_dtype`` is INT48.

    Keyword args (read only when ``check`` is true): ``input_dtype`` and
    ``scale32``.

    Returns the ErrorIf info dict.
    """
    flagged = False
    if check:
        in_dtype = kwargs['input_dtype']
        wants_scale32 = kwargs['scale32']
        flagged = bool(wants_scale32 and in_dtype == DType.INT48)

    return {
        "error_name": ErrorIf.ScaleTrue,
        "error_result": flagged,
        "error_reason": "Scale set to true but input type is INT48",
        "param_reqs": {"rank": None, "dtype": [DType.INT48], "shape": None},
    }
2592
@staticmethod
def evScaleNotTrue(check=False, **kwargs):
    """Report whether ``double_round`` is set while ``scale32`` is not.

    Keyword args (read only when ``check`` is true): ``scale32`` and
    ``double_round``.

    Returns the ErrorIf info dict.
    """
    flagged = False
    if check:
        wants_scale32 = kwargs['scale32']
        wants_double_round = kwargs['double_round']
        flagged = bool(not wants_scale32 and wants_double_round)

    return {
        "error_name": ErrorIf.ScaleNotTrue,
        "error_result": flagged,
        "error_reason": "Scale set to false but double round set to true",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2613
@staticmethod
def evTensorSizeInputOutputMismatch(check=False, **kwargs):
    """Report whether input and output element counts differ.

    Keyword args (read only when ``check`` is true): ``input_shape`` and
    ``output_shape``. The element count is ``np.prod`` of each shape.

    Returns the ErrorIf info dict.
    """
    sizes_differ = False
    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['output_shape']
        in_elems, out_elems = np.prod(in_shape), np.prod(out_shape)
        sizes_differ = bool(in_elems != out_elems)

    return {
        "error_name": ErrorIf.TensorSizeInputOutputMismatch,
        "error_result": sizes_differ,
        "error_reason": "Input tensor size does not match output tensor size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2636
@staticmethod
def evStartSmallerZero(check=False, **kwargs):
    """Report whether any ``start`` coordinate is negative.

    Keyword args (read only when ``check`` is true): ``input_shape`` and
    ``start``. Only checked when ``len(start)`` equals the input rank.

    Returns the ErrorIf info dict.
    """
    negative = False
    if check:
        in_shape = kwargs['input_shape']
        begin = kwargs['start']
        if len(begin) == len(in_shape):
            negative = any(b < 0 for b in begin)

    return {
        "error_name": ErrorIf.StartSmallerZero,
        "error_result": negative,
        "error_reason": "Starting point smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2660
2661
@staticmethod
def evSizeSmallerEqualZero(check=False, **kwargs):
    """Report whether any ``size`` entry is zero or negative.

    Keyword args (read only when ``check`` is true): ``input_shape`` and
    ``size``. Only checked when ``len(size)`` equals the input rank.

    Returns the ErrorIf info dict.
    """
    non_positive = False
    if check:
        in_shape = kwargs['input_shape']
        extent = kwargs['size']
        if len(extent) == len(in_shape):
            non_positive = any(e <= 0 for e in extent)

    return {
        "error_name": ErrorIf.SizeSmallerEqualZero,
        "error_result": non_positive,
        "error_reason": "Size smaller than or equal to zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2685
2686
@staticmethod
def evStartSizeOutsideBounds(check=False, **kwargs):
    """Report whether ``start + size`` exceeds an input dimension.

    Keyword args (read only when ``check`` is true): ``input_shape``,
    ``start`` and ``size``. Only checked when both lengths equal the input
    rank.

    Returns the ErrorIf info dict.
    """
    out_of_bounds = False
    if check:
        dims = kwargs['input_shape']
        begin = kwargs['start']
        extent = kwargs['size']
        rank = len(dims)
        if len(begin) == rank and len(extent) == rank:
            out_of_bounds = any(b + e > d for b, e, d in zip(begin, extent, dims))

    return {
        "error_name": ErrorIf.StartSizeOutsideBounds,
        "error_result": out_of_bounds,
        "error_reason": "starting point plus size larger than input dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2711
2712
@staticmethod
def evSizeOutputShapeMismatch(check=False, **kwargs):
    """Report whether the requested ``size`` differs from ``output_shape``.

    Keyword args (read only when ``check`` is true): ``input_shape``,
    ``output_shape`` and ``size``. Only checked when ``len(size)`` equals
    the input rank.

    An output whose rank differs from ``len(size)`` cannot match it and is
    flagged. Previously such an output raised IndexError (shorter) or was
    only partially compared (longer).

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.SizeOutputShapeMismatch
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Size does not match output dimension"

    if check:
        input_shape = kwargs['input_shape']
        output_shape = kwargs['output_shape']
        size = kwargs['size']
        rank = len(input_shape)
        if len(size) == rank:
            if len(output_shape) != rank:
                error_result = True
            else:
                error_result = any(s != o for s, o in zip(size, output_shape))

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2737
@staticmethod
def evInputSizeStartLengthMismatch(check=False, **kwargs):
    """Report whether ``start`` or ``size`` length differs from the input rank.

    Keyword args (read only when ``check`` is true): ``input_shape``,
    ``start`` and ``size``.

    Returns the ErrorIf info dict.
    """
    lengths_differ = False
    if check:
        in_shape = kwargs['input_shape']
        begin = kwargs['start']
        extent = kwargs['size']
        rank = len(in_shape)
        lengths_differ = len(begin) != rank or len(extent) != rank

    return {
        "error_name": ErrorIf.InputSizeStartLengthMismatch,
        "error_result": lengths_differ,
        "error_reason": "rank of input not equal to length of start or size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2760
@staticmethod
def evIndexOutsideBounds(check=False, **kwargs):
    """Report whether any permutation index is not a valid dimension index.

    Keyword args (read only when ``check`` is true): ``input_shape`` and
    ``perms``. A valid index satisfies ``0 <= index < rank``, so
    ``index == rank`` is out of bounds (previously missed by ``> rank``).

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.IndexOutsideBounds
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index outside of allowed bounds"

    if check:
        input_shape = kwargs['input_shape']
        perms = kwargs['perms']
        rank = len(input_shape)
        error_result = any(index < 0 or index >= rank for index in perms)

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2784
@staticmethod
def evIndexUsedTwice(check=False, **kwargs):
    """Report whether any permutation index appears more than once.

    Keyword args (read only when ``check`` is true): ``perms``, a sequence
    of hashable indices.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.IndexUsedTwice
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index used multiple times"

    if check:
        perms = kwargs['perms']
        # Duplicates collapse in a set, so a shorter set means a repeated index
        error_result = len(set(perms)) != len(perms)

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002811
@staticmethod
def evMaxSmallerMin(check=False, **kwargs):
    """Report whether ``max_val`` is smaller than ``min_val``.

    Keyword args (read only when ``check`` is true): ``max_val`` and
    ``min_val``.

    Returns the ErrorIf info dict.
    """
    inverted = False
    if check:
        upper = kwargs['max_val']
        lower = kwargs['min_val']
        inverted = bool(upper < lower)

    return {
        "error_name": ErrorIf.MaxSmallerMin,
        "error_result": inverted,
        "error_reason": "Max value smaller than min value",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None},
    }
2833
@staticmethod
def evConcatInputRankMismatch(check=False, **kwargs):
    """Report whether any CONCAT input rank differs from ``input_shape``'s rank.

    Keyword args (read only when ``check`` is true): ``inputs`` (objects
    exposing ``.shape``) and ``input_shape``.

    Returns the ErrorIf info dict.
    """
    ranks_differ = False
    if check:
        tensors = kwargs['inputs']
        expected_rank = len(kwargs['input_shape'])
        ranks_differ = any(len(t.shape) != expected_rank for t in tensors)

    return {
        "error_name": ErrorIf.ConcatInputRankMismatch,
        "error_result": ranks_differ,
        "error_reason": "Input ranks are not identical",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None},
    }
2855
@staticmethod
def evConcatInputDimMismatch(check=False, **kwargs):
    """Report whether a CONCAT input differs from ``input_shape`` off the concat axis.

    Keyword args (read only when ``check`` is true): ``inputs`` (objects
    exposing ``.shape``), ``input_shape`` and ``axis``. Dimensions are
    compared only if every input has the same rank as ``input_shape``.

    Returns the ErrorIf info dict.
    """
    mismatch = False
    if check:
        tensors = kwargs['inputs']
        ref_shape = kwargs['input_shape']
        axis = kwargs['axis']
        ref_rank = len(ref_shape)

        if all(len(t.shape) == ref_rank for t in tensors):
            mismatch = any(
                dim != ref_shape[i]
                for t in tensors
                for i, dim in enumerate(t.shape)
                if i != axis
            )

    return {
        "error_name": ErrorIf.ConcatInputDimMismatch,
        "error_result": mismatch,
        "error_reason": "Input dimensions differ on too many axes",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None},
    }
2887
@staticmethod
def evConcatShapeSumMismatch(check=False, **kwargs):
    """Report whether the input sizes along ``axis`` do not sum to the output size.

    Keyword args (read only when ``check`` is true): ``inputs`` (objects
    exposing ``.shape``), ``input_shape``, ``output_shape`` and ``axis``.

    The sum is only checked when:
    - every input has the reference rank,
    - ``0 <= axis < rank`` (the old ``axis > rank`` test let
      ``axis == rank`` through, which then raised IndexError), and
    - ``axis`` is also a valid index into ``output_shape``.

    Returns the ErrorIf info dict.
    """
    error_name = ErrorIf.ConcatShapeSumMismatch
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Sum of dimensions on axis not equal to output dimension"

    if check:
        inputs = kwargs['inputs']
        input_shape = kwargs['input_shape']
        output_shape = kwargs['output_shape']
        axis = kwargs['axis']
        rank = len(input_shape)

        valid_params = (
            all(len(input_tensor.shape) == rank for input_tensor in inputs)
            and 0 <= axis < rank
            and axis < len(output_shape)
        )

        if valid_params:
            axis_dim_sum = sum(input_tensor.shape[axis] for input_tensor in inputs)
            if axis_dim_sum != output_shape[axis]:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2925
@staticmethod
def evInputListThenGraphMismatch(check=False, **kwargs):
    """Report whether inputs ``a``/``b`` differ in shape from the then-graph inputs.

    Keyword args (read only when ``check`` is true): ``a``, ``b`` and
    ``basicBlocks``. Element [1] of ``basicBlocks`` is the then-graph
    block, which exposes ``inputs`` (tensor names) and ``tensors`` (a
    name-indexed mapping).

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        lhs, rhs = kwargs['a'], kwargs['b']
        then_graph = kwargs['basicBlocks'][1]
        names = then_graph.inputs
        tens = then_graph.tensors
        shapes_differ = bool(
            lhs.shape != tens[names[0]].shape
            or rhs.shape != tens[names[1]].shape
        )

    return {
        "error_name": ErrorIf.CondIfInputListThenGraphMismatch,
        "error_result": shapes_differ,
        "error_reason": "Input list shape does not match then-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2950
2951
@staticmethod
def evInputListElseGraphMismatch(check=False, **kwargs):
    """Report whether inputs ``a``/``b`` differ in shape from the else-graph inputs.

    Keyword args (read only when ``check`` is true): ``a``, ``b`` and
    ``basicBlocks``. Element [2] of ``basicBlocks`` is the else-graph
    block, which exposes ``inputs`` (tensor names) and ``tensors`` (a
    name-indexed mapping).

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        lhs, rhs = kwargs['a'], kwargs['b']
        else_graph = kwargs['basicBlocks'][2]
        names = else_graph.inputs
        tens = else_graph.tensors
        shapes_differ = bool(
            lhs.shape != tens[names[0]].shape
            or rhs.shape != tens[names[1]].shape
        )

    return {
        "error_name": ErrorIf.CondIfInputListElseGraphMismatch,
        "error_result": shapes_differ,
        "error_reason": "Input list shape does not match else-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2976
2977
@staticmethod
def evOutputListThenGraphMismatch(check=False, **kwargs):
    """Report whether the then-graph's first output shape differs from block 0's first output.

    Keyword args (read only when ``check`` is true): ``basicBlocks``.
    Element [0] is presumably the graph holding the COND_IF op (TODO
    confirm); element [1] is the then-graph. Both expose ``outputs`` and
    ``tensors``.

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        blocks = kwargs['basicBlocks']
        outer, then_graph = blocks[0], blocks[1]
        outer_out = outer.tensors[outer.outputs[0]]
        then_out = then_graph.tensors[then_graph.outputs[0]]
        shapes_differ = bool(then_out.shape != outer_out.shape)

    return {
        "error_name": ErrorIf.CondIfOutputListThenGraphMismatch,
        "error_result": shapes_differ,
        "error_reason": "Output list shape does not match then-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3003
3004
@staticmethod
def evOutputListElseGraphMismatch(check=False, **kwargs):
    """Report whether the else-graph's first output shape differs from block 0's first output.

    Keyword args (read only when ``check`` is true): ``basicBlocks``.
    Element [0] is presumably the graph holding the COND_IF op (TODO
    confirm); element [2] is the else-graph. Both expose ``outputs`` and
    ``tensors``.

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        blocks = kwargs['basicBlocks']
        outer, else_graph = blocks[0], blocks[2]
        outer_out = outer.tensors[outer.outputs[0]]
        else_out = else_graph.tensors[else_graph.outputs[0]]
        shapes_differ = bool(else_out.shape != outer_out.shape)

    return {
        "error_name": ErrorIf.CondIfOutputListElseGraphMismatch,
        "error_result": shapes_differ,
        "error_reason": "Output list shape does not match else-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3030
3031
@staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
    """Report whether block 0's input [1] and output [0] differ in shape.

    Keyword args (read only when ``check`` is true): ``basicBlocks``.
    Element [0] exposes ``inputs``, ``outputs`` and ``tensors``.

    NOTE(review): pairs input index 1 with output index 0. This presumably
    reflects how the WHILE_LOOP test graph is built; confirm against the
    builder.

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        while_graph = kwargs['basicBlocks'][0]
        tens = while_graph.tensors
        in_shape = tens[while_graph.inputs[1]].shape
        out_shape = tens[while_graph.outputs[0]].shape
        shapes_differ = bool(in_shape != out_shape)

    return {
        "error_name": ErrorIf.InputListOutputListMismatch,
        "error_result": shapes_differ,
        "error_reason": "Input list does not match output list",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3055
3056
@staticmethod
def evInputListCondGraphMismatch(check=False, **kwargs):
    """Report whether block 0's inputs differ in shape from the cond-graph inputs.

    Keyword args (read only when ``check`` is true): ``basicBlocks``.
    Element [0] is compared against element [1] (the cond graph):
    - block 0 input [0] against cond input [0];
    - block 0 input [1] against cond input [2].

    NOTE(review): the [1] -> [2] pairing is taken as-is from the original;
    confirm against the WHILE_LOOP test graph layout.

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        blocks = kwargs['basicBlocks']
        outer, cond_graph = blocks[0], blocks[1]
        outer_in, outer_tens = outer.inputs, outer.tensors
        cond_in, cond_tens = cond_graph.inputs, cond_graph.tensors
        shapes_differ = bool(
            outer_tens[outer_in[0]].shape != cond_tens[cond_in[0]].shape
            or outer_tens[outer_in[1]].shape != cond_tens[cond_in[2]].shape
        )

    return {
        "error_name": ErrorIf.InputListCondGraphMismatch,
        "error_result": shapes_differ,
        "error_reason": "Input list does not match cond graph",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3083
3084
@staticmethod
def evInputListBodyGraphInputMismatch(check=False, **kwargs):
    """Report whether block 0's inputs differ in shape from the body-graph inputs.

    Keyword args (read only when ``check`` is true): ``basicBlocks``.
    Element [0] is compared against element [2] (the body graph):
    - block 0 input [0] against body input [0];
    - block 0 input [1] against body input [2].

    NOTE(review): the [1] -> [2] pairing is taken as-is from the original;
    confirm against the WHILE_LOOP test graph layout.

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        blocks = kwargs['basicBlocks']
        outer, body_graph = blocks[0], blocks[2]
        outer_in, outer_tens = outer.inputs, outer.tensors
        body_in, body_tens = body_graph.inputs, body_graph.tensors
        shapes_differ = bool(
            outer_tens[outer_in[0]].shape != body_tens[body_in[0]].shape
            or outer_tens[outer_in[1]].shape != body_tens[body_in[2]].shape
        )

    return {
        "error_name": ErrorIf.InputListBodyGraphInputMismatch,
        "error_result": shapes_differ,
        "error_reason": "Input list does not match body graph input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3111
3112
@staticmethod
def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
    """Report whether block 0's inputs differ in shape from the body-graph outputs.

    Keyword args (read only when ``check`` is true): ``basicBlocks``.
    Element [0] is compared against element [2] (the body graph):
    - block 0 input [0] against body output [0];
    - block 0 input [1] against body output [2].

    NOTE(review): the [1] -> [2] pairing is taken as-is from the original;
    confirm against the WHILE_LOOP test graph layout.

    Returns the ErrorIf info dict.
    """
    shapes_differ = False
    if check:
        blocks = kwargs['basicBlocks']
        outer, body_graph = blocks[0], blocks[2]
        outer_in, outer_tens = outer.inputs, outer.tensors
        body_out, body_tens = body_graph.outputs, body_graph.tensors
        shapes_differ = bool(
            outer_tens[outer_in[0]].shape != body_tens[body_out[0]].shape
            or outer_tens[outer_in[1]].shape != body_tens[body_out[2]].shape
        )

    return {
        "error_name": ErrorIf.InputListBodyGraphOutputMismatch,
        "error_result": shapes_differ,
        "error_reason": "Input list does not match body graph output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3138
3139
@staticmethod
def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
    """Report whether the cond graph's first output is not of dtype BOOL.

    Keyword args (read only when ``check`` is true): ``basicBlocks``, whose
    element [1] is the cond graph exposing ``outputs`` and ``tensors``.

    Returns the ErrorIf info dict.
    """
    not_bool = False
    if check:
        cond_graph = kwargs['basicBlocks'][1]
        first_output = cond_graph.tensors[cond_graph.outputs[0]]
        not_bool = bool(first_output.dtype != DType.BOOL)

    return {
        "error_name": ErrorIf.CondGraphOutputNotMatchingBool,
        "error_result": not_bool,
        "error_reason": "Cond graph output is not a match list of booleans",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
3162
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003163
class TosaInvalidValidator:
    """Predicates that flag generated argument combinations as invalid.

    Each ``iv*`` staticmethod receives a candidate test's parameters as
    keyword arguments and returns True when that combination is invalid.
    (Presumably the generator then drops the candidate; confirm at the call
    site.)
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode / input dtype / output dtype combo is invalid.

        kwargs:
            input_dtype: dtype of the input tensor.
            args: resize argument list; ``args[0]`` is the mode and
                ``args[8]`` the output dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Only these (input, output) pairs are valid. Negating each pair
            # and OR-ing the results is always True (at most one pair can
            # match), so test membership in the valid set instead.
            valid_pairs = (
                (DType.INT8, DType.INT32),
                (DType.INT16, DType.INT48),
                (DType.FLOAT, DType.FLOAT),
            )
            return not any(
                input_dtype == in_dt and output_dtype == out_dt
                for in_dt, out_dt in valid_pairs
            )
        if mode == ResizeMode.NEAREST:
            # Output must match input, and the input must be one of the types
            # accepted above (INT8 / INT16 / FLOAT).
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        # Invalid resize mode
        return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if the RESIZE stride used for the input dtype is not positive.

        FLOAT inputs are checked against the floating-point stride
        ``args[4]``; all other dtypes against the integer stride ``args[1]``.
        Elements [0] and [1] of each stride are checked.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x, stride_y = args[1][0], args[1][1]
        stride_fp_x, stride_fp_y = args[4][0], args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        elif stride_x <= 0 or stride_y <= 0:
            # Negative or zero stride
            return True
        return False

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True if the op would produce an invalid spatial output size.

        kwargs:
            opName: test op name. "*pool2d", "transpose_conv2d*" and names
                containing "conv2d"/"conv3d" are recognized.
            shapeList: input shapes; ``[0]`` is the input, whose dims 1.. are
                used as the spatial dims, and ``[1]`` is the filter for the
                convolutions.
            args: ``args[0]`` strides, ``args[1]`` padding (output padding
                for transpose_conv2d), ``args[2]`` kernel (pool) or dilations
                (conv), ``args[3]`` requested output shape (transpose_conv2d).

        Raises:
            AssertionError: for an unrecognized ``opName``.
        """
        opName = kwargs['opName']
        inputShapes = kwargs['shapeList']
        input_shape = inputShapes[0]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            h = (input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]) // strides[0]
            w = (input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            dilations = args[2]
            output_shape = args[3]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Based on the keras function deconv_output_length, in
                https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
                """
                dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
                return (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad

            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(input_shape[1], strides[0], kernel_shape[0], dilations[0],
                                 padding[0], pad_h)
                w = get_out_size(input_shape[2], strides[1], kernel_shape[1], dilations[1],
                                 padding[1], pad_w)
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            if opName.startswith("depthwise_conv2d"):
                kernel_shape = filter_shape[0:2]
            else:
                kernel_shape = filter_shape[1:-1]

            for i, k in enumerate(kernel_shape):
                dim = (
                    input_shape[i + 1]
                    - k
                    - (k - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        # Explicit raise (not ``assert False``) so the check survives python -O.
        raise AssertionError(f"Unrecognized Op: {opName}")

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if output_shape (``args[3]``) has a non-positive dim 1 or 2."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
3302
3303
Eric Kunzee5e26762020-10-13 16:11:07 -07003304class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003305 # Maximum rank of tensor supported by test generator.
3306 TOSA_TENSOR_MAX_RANK = 6
3307
def __init__(self, args):
    """Initialise the generator from parsed command-line ``args``.

    Reads ``args.output_dir`` and ``args.random_seed``, seeds a NumPy
    Generator with that seed, builds the operator lists and creates a
    TosaQuantGen. No serializer exists until createSerializer() is called,
    and no target shape is forced.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Populated per test by createSerializer().
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    # Operator-list setup; both helpers are defined elsewhere in this class.
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set (setTargetShape), makeShape() returns this shape instead of a random one.
    self.targetted_shape = None
3319
def createSerializer(self, opName, testPath):
    """Start a new serializer rooted at ``<basePath>/<opName>/<testPath>``.

    The relative ``<opName>/<testPath>`` is remembered in ``self.testPath``
    (used later by serialize()); the directory is created if missing.
    """
    relative_dir = os.path.join(opName, testPath)
    self.testPath = relative_dir

    output_dir = os.path.join(self.basePath, relative_dir)
    os.makedirs(output_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(output_dir)
3326
def getSerializer(self):
    """Return the current serializer (the value stored in ``self.ser``)."""
    serializer = self.ser
    return serializer
3329
def serialize(self, testName):
    """Write the current test to ``<basePath>/<testPath>``.

    Two files are produced: ``<testName>.tosa`` holding the bytes from
    ``self.ser.serialize()``, and ``desc.json`` holding the text from
    ``self.ser.writeJson("<testName>.tosa")``.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_name = "{}.tosa".format(testName)

    with open(os.path.join(test_dir, tosa_name), "wb") as fd:
        fd.write(self.ser.serialize())

    with open(os.path.join(test_dir, "desc.json"), "w") as fd:
        fd.write(self.ser.writeJson(tosa_name))
Eric Kunzee5e26762020-10-13 16:11:07 -07003338
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded NumPy Generator.

    Args:
        seed: seed for the new generator. When None, ``self.random_seed + 1``
            is used.
    """
    # ``is None`` (not ``== None``): an array-like seed would otherwise be
    # compared element-wise and make the ``if`` ambiguous.
    if seed is None:
        seed = self.random_seed + 1
    self.rng = np.random.default_rng(seed)
3343
def getRandTensor(self, shape, dtype):
    """Return a random NumPy array of ``shape`` with values valid for ``dtype``.

    BOOL -> numpy bools; INT4 -> TOSA weight range -7..7; INT8/UINT8/INT16/
    INT32 -> the type's full range; these are all stored as int32. INT48 is
    stored as int64. FLOAT -> float32 drawn from ``rng.random`` ([0, 1)).

    Raises:
        Exception: for an unrecognized dtype.
    """
    if dtype == DType.BOOL:
        # The old unused ``np_dt = np.bool`` line is gone: that alias was
        # removed in NumPy 1.24 and raised AttributeError there.
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        return np.int32(self.rng.integers(low=-7, high=8, size=shape))
    elif dtype == DType.INT8:
        return np.int32(self.rng.integers(low=-128, high=128, size=shape))
    elif dtype == DType.UINT8:
        return np.int32(self.rng.integers(low=0, high=256, size=shape))
    elif dtype == DType.INT16:
        return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
    elif dtype == DType.INT32:
        return np.int32(
            self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
        )
    elif dtype == DType.INT48:
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003369
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per (shape, dtype) pair.

    ``shape_list`` and ``dtype_list`` must have equal length. Returns the
    tensors from ``self.ser.addPlaceholder`` in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, values))
    return tensors
3380
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant per (shape, dtype) pair.

    ``shape_list`` and ``dtype_list`` must have equal length. Returns the
    tensors from ``self.ser.addConst`` in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, values))
    return tensors
3391
def makeShape(self, rank):
    """Return an int32 shape array.

    If a target shape has been set (and is truthy), it is returned as int32
    regardless of ``rank``. Otherwise ``rank`` dimensions are drawn from
    [tensor_shape_range[0], tensor_shape_range[1]).
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003402
def setTargetShape(self, shape):
    """Force makeShape() to return ``shape`` (pass a falsy value to clear)."""
    self.targetted_shape = shape
3405
def randInt(self, low=0, high=256):
    """Return one random integer in [low, high) as a NumPy int32 scalar."""
    sample = self.rng.integers(low=low, high=high, size=1)
    return np.int32(sample)[0]
3408
def getRandNumberDType(self, dtype):
    """Return a single random scalar suitable for ``dtype``.

    FLOAT -> ``rng.random()`` in [0, 1); BOOL -> a random choice of
    False/True; integer types -> uniform over the type's range (INT4 uses
    the TOSA weight range -7..7, UINT8 uses 0..255, matching getRandTensor).
    INT48 is returned as int64, the other integer types as int32.

    Raises:
        Exception: for an unknown dtype.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        # Previously fell through to "Unknown dtype" although getRandTensor,
        # typeStr and typeWidth all support UINT8.
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
3431
def shapeStr(self, shape):
    """Render ``shape`` as its dimensions joined by "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003440
def typeStr(self, t):
    """Return the short dtype name used in test names.

    A list of dtypes (at least two entries) is rendered from its first two
    elements as "<first>x<second>", e.g. "i8xi4".

    Raises:
        Exception: if ``t`` is not a recognized dtype.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))

    short_names = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for known, name in short_names:
        if t == known:
            return name
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003464
def typeWidth(self, t):
    """Return the bit width of dtype ``t``.

    Integer types report their storage width; FLOAT reports 32 and BOOL 1.

    Raises:
        Exception: if ``t`` is not a recognized dtype.
    """
    if t == DType.INT4:
        return 4
    elif t == DType.INT8:
        return 8
    elif t == DType.UINT8:
        return 8
    elif t == DType.INT16:
        return 16
    elif t == DType.INT32:
        return 32
    elif t == DType.INT48:
        return 48
    elif t == DType.FLOAT:
        return 32
    elif t == DType.BOOL:
        return 1
    else:
        # This helper computes a width; the old message (copied from typeStr)
        # wrongly said "cannot convert to string".
        raise Exception("Unknown dtype, cannot determine width: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003485
3486 # Argument generators
3487 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
3488 # Where the string descriptor is used to generate the test name and
3489 # The build_fcn_arg_list is expanded and passed to the operator test
3490 # build function
3491
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator on tensor ``a`` and return its output tensor.

    ``op`` is either a bare int op code or an op dict with "op" and
    "operands" entries. A bare op code, or an op dict for IDENTITY, is added
    directly without any ERROR_IF processing. Otherwise the input/output
    lists may be invalidated for ``error_name`` and the ERROR_IF validators
    run before the operator is added.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Bare op codes (ints) and IDENTITY bypass the ERROR_IF path.
    if isinstance(op, int) or op["op"] == Op.IDENTITY:
        op_code = op if isinstance(op, int) else op["op"]
        self.ser.addOperator(op_code, a.name, result_tens.name, None, qinfo)
        return result_tens

    # A wrong, non 8-bit output type needs matching quantization info.
    if error_name == ErrorIf.WrongOutputType and result_tens.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, a.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op["op"], ins, outs, None, qinfo)
    return result_tens
3534
def build_binary_broadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a broadcasting two-input operator and return its output tensor.

    Args:
        op: op dict with "op" (op code) and "operands" entries.
        a, b: input tensors.
        validator_fcns: ERROR_IF validators to run. Defaults to None, matching
            the other build_* helpers; previously this argument was required.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The output tensor from OutputShaper.binaryBroadcastOp.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
3563
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input, non-broadcasting operator and return its output tensor.

    NOTE(review): ``validator_fcns`` and ``error_name`` are accepted but not
    used, so no ERROR_IF validation runs for this builder.
    """
    out_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tens.name])
    return out_tens
3568
def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None):
    """Add an arithmetic-right-shift operator with inputs ``a`` and ``b``.

    ``round`` is stored in the ArithmeticRightShiftAttribute. The output
    comes from OutputShaper.binaryBroadcastOp; the input/output lists may be
    invalidated for ``error_name`` before the ERROR_IF validators run.

    Returns:
        The output tensor.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], ins, outs, shift_attr)
    return out_tens
3599
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a MUL operator on ``a`` and ``b`` and return its output tensor.

    For non-FLOAT inputs the output dtype is forced to INT32. For the
    WrongOutputType ERROR_IF case it is then overwritten with a random
    choice of INT8/INT16/INT48. ``shift`` is stored in the MulAttribute.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Integer multiplies always produce INT32.
    if a.dtype != DType.FLOAT:
        out_tens.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        bad_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        out_tens.setDtype(bad_dtype)

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], ins, outs, mul_attr)
    return out_tens
3639
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a TABLE operator applying lookup ``table`` to tensor ``a``.

    ``table`` is stored in the TableAttribute; the output comes from
    OutputShaper.tableOp. The input/output lists may be invalidated for
    ``error_name`` before the ERROR_IF validators run.

    Returns:
        The output tensor.
    """
    out_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op["op"], ins, outs, table_attr)
    return out_tens
3670
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a SELECT operator with inputs ``cond``, ``a`` and ``b``.

    The output comes from OutputShaper.selectOp. The input/output lists may
    be invalidated for ``error_name`` before the ERROR_IF validators run.

    Returns:
        The output tensor.
    """
    out_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op["op"], ins, outs)
    return out_tens
3700
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input comparison operator and return its output tensor.

    The output comes from OutputShaper.binaryComparisonOp. The input/output
    lists may be invalidated for ``error_name`` before the ERROR_IF
    validators run.
    """
    out_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    self.ser.addOperator(op["op"], ins, outs)
    return out_tens
3730
def build_argmax(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add an ARGMAX operator over ``axis`` of tensor ``a``.

    Args:
        op: op dict with "op" (op code) and "operands" entries.
        a: input tensor.
        axis: reduction axis, stored in the AxisAttribute.
        validator_fcns: ERROR_IF validators to run. Defaults to None, like the
            other build_* helpers; previously this argument was required.
        error_name: ERROR_IF case being generated. Defaults to None;
            previously required.

    Returns:
        The output tensor from OutputShaper.argmaxOp.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3762
def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
    """Add a 2D pooling operator (``op["op"]``) on ``input``.

    ``kernel``, ``stride`` and ``pad`` are stored in the PoolAttribute. For
    the WrongInputType ERROR_IF case with a non 8-bit input, fresh
    quantization info is built from the input and output dtypes.

    Returns:
        The output tensor from OutputShaper.pool2dOp.
    """
    out_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)

    if error_name == ErrorIf.WrongInputType and input.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, input.dtype),
            TosaQuantGen.getQinfo(self, out_tens.dtype),
        )

    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=out_tens,
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad)

    self.ser.addOperator(op["op"], ins, outs, pool_attr, qinfo)
    return out_tens
3805
def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
    """Add a CONV2D operator with inputs ``ifm``, ``filter`` and ``bias``.

    ``padding`` must have 4 entries. ``padding``, ``strides`` and
    ``dilations`` go into the ConvAttribute. For the WrongInputType
    ERROR_IF case with a non 8-bit input, fresh ConvQuantInfo is built.

    Returns:
        The output tensor from OutputShaper.conv2dOp.
    """
    assert len(padding) == 4
    out_tens = OutputShaper.conv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out_tens.dtype),
        )

    operand_count = sum(op["operands"])
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        input_list=ins,
        num_operands=operand_count,
        output_list=outs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    self.ser.addOperator(op["op"], ins, outs, conv_attr, qinfo)
    return out_tens
3850
def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
    """Add a CONV3D operator with inputs ``ifm``, ``filter`` and ``bias``.

    ``padding`` must have 6 entries. ``padding``, ``strides`` and
    ``dilations`` go into the ConvAttribute. For the WrongInputType
    ERROR_IF case with a non 8-bit input, fresh ConvQuantInfo is built.

    Returns:
        The output tensor from OutputShaper.conv3dOp.
    """
    assert len(padding) == 6
    out_tens = OutputShaper.conv3dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out_tens.dtype),
        )

    operand_count = sum(op["operands"])
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        input_list=ins,
        num_operands=operand_count,
        output_list=outs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    self.ser.addOperator(op["op"], ins, outs, conv_attr, qinfo)
    return out_tens
3895
def build_transpose_conv2d(
    self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TRANSPOSE_CONV2D operator with inputs ``ifm``, ``filter`` and ``bias``.

    ``outpad`` must have 2 entries. ``outpad``, ``stride``, ``dilation`` and
    ``output_shape`` go into the TransposeConvAttribute. For the
    WrongInputType ERROR_IF case with a non 8-bit input, fresh ConvQuantInfo
    is built.

    Returns:
        The output tensor from OutputShaper.transposeConv2DOp.
    """
    assert len(outpad) == 2
    out_tens = OutputShaper.transposeConv2DOp(self.ser, self.rng, ifm, output_shape, error_name)

    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out_tens.dtype),
        )

    operand_count = sum(op["operands"])
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        input_list=ins,
        num_operands=operand_count,
        output_list=outs,
        pad=outpad,
        stride=stride,
        dilation=dilation,
        input_shape=ifm.shape,
    )

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    self.ser.addOperator(op["op"], ins, outs, tconv_attr, qinfo)
    return out_tens
3940
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a depthwise-conv2d operator and return its output tensor.

    The output tensor comes from OutputShaper.depthwiseConv2dOp. The operand
    name lists may then be invalidated for ERROR_IF tests and are checked by
    TosaErrorValidator. Finally the operator is serialized with a
    ConvAttribute.

    Args:
        op: operator description; its "op" and "operands" entries are read.
        ifm, filter, bias: input, weight and bias tensors.
        strides, padding, dilations: convolution attribute values.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: quantization info; rebuilt when a WrongInputType test uses an
            input dtype other than INT8/UINT8.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # A WrongInputType test may use a non-8-bit input, so derive qinfo from
    # the dtypes actually in play.
    eight_bit_types = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_types:
        qinfo = ts.TosaSerializerQuantInfo()
        in_qval = TosaQuantGen.getQinfo(self, ifm.dtype)
        out_qval = TosaQuantGen.getQinfo(self, result_tens.dtype)
        qinfo.ConvQuantInfo(in_qval, out_qval)

    # ERROR_IF tests may corrupt the operand name lists.
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [result_tens.name],
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)
    self.ser.addOperator(op["op"], input_list, output_list, conv_attr, qinfo)
    return result_tens
3986
def build_fully_connected(
    self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a fully-connected operator and return its output tensor.

    The output tensor comes from OutputShaper.fullyConnectedOp. ERROR_IF
    invalidation and validation run next, and the operator is then added
    without an attribute but with *qinfo*.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [result_tens.name],
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
    return result_tens
4018
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Build a matmul operator on tensors *a* and *b* and return its output.

    The output tensor comes from OutputShaper.matmulOp. Both operands'
    shapes and dtypes are passed to the ERROR_IF validator, and the operator
    is then added without an attribute but with *qinfo*.
    """
    result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
    return result_tens
4049
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Build a reduction of tensor *a* along *axis* and return its output.

    The output tensor comes from OutputShaper.reduceOp. The axis is both
    validated and serialized as an AxisAttribute.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result_tens
4081
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a clamp of tensor *a* with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the larger becomes
    ``max_val`` and the smaller ``min_val``. For ErrorIf.MaxSmallerMin the
    pair is redrawn until the values differ, then deliberately swapped so
    that ``max_val < min_val``. FLOAT inputs carry the bounds in the last
    two ClampAttribute arguments; other dtypes carry them in the first two.

    Returns the output tensor created by OutputShaper.unaryOp.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _draw_bounds():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = _draw_bounds()
    want_inverted = error_name == ErrorIf.MaxSmallerMin
    if want_inverted:
        # Equal values can never give max < min, so keep drawing.
        while bounds[0] == bounds[1]:
            bounds = _draw_bounds()

    smaller, larger = min(bounds), max(bounds)
    if want_inverted:
        max_val, min_val = smaller, larger
    else:
        max_val, min_val = larger, smaller

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype == DType.FLOAT:
        clamp_args = (0, 0, min_val, max_val)
    else:
        clamp_args = (min_val, max_val, 0, 0)
    clamp_attr.ClampAttribute(*clamp_args)

    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)
    return result_tens
4129
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a leaky-relu of tensor *a* and return its output tensor.

    The LeakyReluAttribute receives a random FLOAT value. No ERROR_IF
    validation is performed here.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    relu_attr = ts.TosaSerializerAttribute()
    coeff = self.getRandNumberDType(DType.FLOAT)
    relu_attr.LeakyReluAttribute(coeff)

    self.ser.addOperator(op["op"], [a.name], [result_tens.name], relu_attr)
    return result_tens
4138
# TODO: PRELU still needs an additional type/input.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a prelu of tensor *a* (single input, no attribute) and return its output.

    No ERROR_IF validation is performed here.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    inputs, outputs = [a.name], [result_tens.name]
    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
4145
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a sigmoid of tensor *a* and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp. The operator is
    added without an attribute after ERROR_IF invalidation and validation.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
4173
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a tanh of tensor *a* and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp. The operator is
    added without an attribute after ERROR_IF invalidation and validation.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
4201
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a concat over a variable number of tensors and return its output.

    The positional arguments are the input tensors followed by the concat
    axis. The axis travels at the end because the tensor list has variable
    length. Unless a WrongInputType test is being generated, the axis is
    asserted to be exactly an ``int``. The first input tensor's shape and
    dtype are used for validation, and the axis is serialized as an
    AxisAttribute.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    tensors, axis = a[:-1], a[-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor.name for tensor in tensors],
        [result_tens.name],
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=result_tens.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=tensors,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004246
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a pad of tensor *a* and return its output tensor.

    The PadAttribute receives ``padding.flatten()`` together with both pad
    constants. The attribute is built before ERROR_IF invalidation and
    validation, and the operator is added with *qinfo*.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr, qinfo)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004281
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a reshape of tensor *a* to *newShape* and return its output tensor.

    The output tensor comes from OutputShaper.reshapeOp. *newShape* is
    serialized as a ReshapeAttribute.
    """
    result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    shape_attr = ts.TosaSerializerAttribute()
    shape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], input_list, output_list, shape_attr)
    return result_tens
4312
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a reverse of tensor *a* along *axis* and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp. The axis is both
    validated and serialized as an AxisAttribute.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result_tens
4344
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a transpose of tensor *a* using *perms* and return its output tensor.

    The TransposeAttribute is built before ERROR_IF invalidation and
    validation, and *perms* is also handed to the validator.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list, perm_attr)
    return result_tens
4377
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a slice of tensor *a* from *start* with *size* and return its output.

    *start* and *size* are passed to the ERROR_IF validator and serialized
    as a SliceAttribute.
    """
    result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        start=start,
        size=size,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], input_list, output_list, slice_attr)
    return result_tens
4410
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a tile of tensor *a* by *multiples* and return its output tensor.

    The output tensor comes from OutputShaper.tileOp. *multiples* is
    serialized as a TileAttribute.
    """
    result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], input_list, output_list, tile_attr)
    return result_tens
4441
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a gather from *values* using a freshly generated index tensor.

    An INT32 constant index tensor of shape ``[values.shape[0], W]`` is
    added to the serializer. W comes from ``self.randInt`` over
    ``self.args.tensor_shape_range``, and every index is drawn from
    ``[0, values.shape[1])`` so it stays inside the values tensor.

    Returns the output tensor created by OutputShaper.gatherOp.
    """
    k_dim = values.shape[1]
    shape_range = self.args.tensor_shape_range
    w_dim = self.randInt(shape_range[0], shape_range[1])
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=result_tens.shape,
        input_dtype=values.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
4483
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a scatter of *input* into *values_in* using generated indices.

    An INT32 constant index tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is added to the serializer,
    with every index drawn from ``[0, values_in.shape[1])`` so it stays
    inside *values_in*. Validation uses the shape and dtype of *input*.

    Returns the output tensor created by OutputShaper.scatterOp.
    """
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [result_tens.name],
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        output_shape=result_tens.shape,
        input_dtype=input.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
4523
Matthew Haddon848efb42021-09-09 12:30:53 +01004524
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a resize of tensor *input* and return its output tensor.

    All resize parameters are forwarded to OutputShaper.resizeOp and to the
    ERROR_IF validator. They are then serialized as a ResizeAttribute.
    Validation uses the *input_dtype*/*output_dtype* arguments rather than
    the tensors' own dtypes, and *output_dims* as the output shape.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
4593
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input, two-output pass-through operator.

    Each input gets its own output tensor from OutputShaper.unaryOp. Both
    names are registered with the serializer, but only the first output
    tensor is returned.

    Args:
        op: operator description. Like the other build_* methods, the
            operator code is read from ``op["op"]``; a bare operator value
            (non-dict) is still accepted unchanged.
        val, val2: the two input tensors.
        validator_fcns: unused here; kept for signature parity.
        error_name: ErrorIf case forwarded to OutputShaper.unaryOp.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Previously the whole `op` object was handed to addOperator, while every
    # other builder passes op["op"]. Accept both forms for compatibility.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
4601
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register *val* as a serializer output tensor and return it unchanged.

    *op*, *validator_fcns* and *error_name* are accepted only for signature
    parity with the other build_* methods.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07004605
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a cast of tensor *val* to *out_dtype* and return its output tensor.

    The output tensor comes from OutputShaper.typeConversionOp. The operator
    is added without an attribute after ERROR_IF invalidation and validation.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    # ERROR_IF tests may corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
4634
Matthew Haddonc2025212021-10-08 21:21:05 +01004635 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004636 result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004637
4638 if per_channel:
4639 nc = val.shape[-1]
4640 else:
4641 nc = 1
4642
4643 in_type_width = self.typeWidth(val.dtype)
4644 out_type_width = self.typeWidth(out_dtype)
4645
Kevin Cheng3a478572021-01-22 17:21:02 -08004646 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004647 input_zp = self.randInt(-128, 128)
4648 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07004649 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004650 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07004651 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01004652 elif error_name == ErrorIf.InputZeroPointNotZero:
4653 input_zp = self.randInt(-128, 128)
4654 if input_zp == 0:
4655 input_zp = input_zp + self.rng.integers(1, 10)
4656 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004657 else:
4658 input_zp = 0
4659
Kevin Cheng3a478572021-01-22 17:21:02 -08004660 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004661 output_zp = self.randInt(-128, 128)
4662 out_type_width = out_type_width + 1
4663 elif out_dtype == DType.UINT8:
4664 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07004665 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01004666 elif error_name == ErrorIf.OutputZeroPointNotZero:
4667 output_zp = self.randInt(-128, 128)
4668 if output_zp == 0:
4669 output_zp = output_zp + self.rng.integers(1, 10)
4670 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004671 else:
4672 output_zp = 0
4673
4674 # Calculate scale based on:
4675 # scale = a *(2^output_width)/(2^input_width))
4676
4677 a = np.float32(self.rng.random(size=[nc]))
4678 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
4679
4680 if scale32:
4681 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01004682 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07004683 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
4684 else:
4685 # Cap the scaling at 2^15 - 1 for scale16
4686 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
4687
Kevin Cheng550ccc52021-03-03 11:21:43 -08004688 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07004689
4690 multiplier_arr = np.int32(np.zeros(shape=[nc]))
4691 shift_arr = np.int32(np.zeros(shape=[nc]))
4692
4693 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004694 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
4695 scale_arr[i], scale32
4696 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004697
Kevin Cheng550ccc52021-03-03 11:21:43 -08004698 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07004699
Matthew Haddonc2025212021-10-08 21:21:05 +01004700 # Invalidate Input/Output list for error if checks.
4701 input_list = [val.name]
4702 output_list = [result_tens.name]
4703 pCount, cCount = op["operands"]
4704 num_operands = pCount + cCount
4705 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
4706
4707 qinfo = (input_zp, output_zp)
4708 TosaErrorValidator.evValidateErrorIfs(
4709 self.ser,
4710 validator_fcns,
4711 error_name,
4712 op=op,
4713 input_dtype=val.dtype,
4714 output_dtype=out_dtype,
4715 input_shape=val.shape,
4716 qinfo=qinfo,
4717 scale32 = scale32,
4718 double_round = double_round,
4719 input_list=input_list,
4720 output_list=output_list,
4721 result_tensor=result_tens,
4722 num_operands=num_operands,
4723 )
4724
Eric Kunzee5e26762020-10-13 16:11:07 -07004725 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004726 attr.RescaleAttribute(
4727 input_zp,
4728 output_zp,
4729 multiplier_arr,
4730 shift_arr,
4731 scale32,
4732 double_round,
4733 per_channel,
4734 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004735
Matthew Haddonc2025212021-10-08 21:21:05 +01004736 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004737 return result_tens
4738
def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
    """Build a COND_IF whose THEN/ELSE blocks each just output a constant.

    Only the shape of ``then_tens`` is used. The value of ``else_tens`` is
    never read. Random INT32 constants of that shape are created inside
    the THEN and ELSE blocks.

    For the CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the
    selected block instead outputs a constant whose every dimension differs
    from the expected output shape.

    Args:
        op: operator description dict; ``op['op']`` is the operator added.
        then_tens: tensor whose ``shape`` is used as the result shape.
        else_tens: not read.
        cond: value stored in the scalar BOOL condition constant.
        validator_fcns: ERROR_IF validator functions (negative tests).
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The output tensor of the COND_IF operator.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    # Make then/else tensors
    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
        incorrect_shape = deepcopy(then_tens.shape)
        for i in range(len(incorrect_shape)):
            delta = self.rng.choice([-3, -2, 2, 3])
            new_dim = incorrect_shape[i] + delta
            if new_dim < 1:
                # Shrinking would give a zero/negative dimension (negative
                # sizes make rng.integers raise); grow by the same amount
                # instead so the shape still mismatches.
                new_dim = incorrect_shape[i] - delta
            incorrect_shape[i] = new_dim
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

    self.ser.startBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.startBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_tens)

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks
    )

    return result_tens
4796
def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
    """Build a COND_IF whose THEN/ELSE blocks combine ``a`` and ``b``.

    FLOAT/INT32 inputs use ADD (THEN) and SUB (ELSE). INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (THEN) and LOGICAL_LEFT_SHIFT (ELSE).

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the selected block gets an input or output tensor whose every dimension
    differs from ``a.shape``.

    Args:
        op: operator description dict; ``op['op']`` is the operator added.
        a, b: input tensors; ``a`` also defines the result shape/dtype.
        cond: value stored in the scalar BOOL condition constant.
        validator_fcns: ERROR_IF validator functions (negative tests).
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The output tensor of the COND_IF operator.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
                      ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            delta = self.rng.choice([-3, -2, 2, 3])
            new_dim = incorrect_shape[i] + delta
            if new_dim < 1:
                # Never produce a zero/negative dimension; grow instead so
                # the shape still mismatches.
                new_dim = incorrect_shape[i] - delta
            incorrect_shape[i] = new_dim
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be named ``op``: that would shadow the
    # operator dict passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
                (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
                (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks
    )

    return result_tens
4862
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` into a zero-initialised tensor.

    The COND block evaluates ``GREATER(iter, 0)``. The BODY block computes
    ``acc = a + acc`` and ``iter = iter - 1``, and outputs (iter, a, acc).

    ERROR_IF variants selected by ``error_name``:
      * InputListOutputListMismatch: the loop's acc output gets a perturbed shape.
      * InputListCondGraphMismatch: COND inputs use perturbed iter/acc copies.
      * InputListBodyGraphInputMismatch: BODY inputs use perturbed copies.
      * InputListBodyGraphOutputMismatch: BODY outputs use perturbed shapes.
      * CondGraphOutputNotMatchingBool: COND output gets a non-BOOL dtype.

    Args:
        op: operator description dict; ``op['op']`` is the operator added.
        a: input tensor added on every iteration.
        iter_val: initial value of the scalar INT32 loop counter.
        validator_fcns: ERROR_IF validator functions (negative tests).
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The accumulator output tensor of the WHILE_LOOP.
    """
    def perturbed_copy(tensor):
        # Deep copy of ``tensor`` with each dimension shifted by a random offset
        clone = deepcopy(tensor)
        for dim_idx in range(len(clone.shape)):
            clone.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        return clone

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator starts as zeros with the shape/dtype of ``a``
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op['op'],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_mismatch_errors = [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]
    if error_name in graph_mismatch_errors:
        incorrect_iter = perturbed_copy(iter_tens)
        if not incorrect_iter.shape:
            # A rank-0 counter has no dimension to perturb, so add one
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        incorrect_acc = perturbed_copy(acc)

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tensor in cond_inputs:
        self.ser.addInputTensor(tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_dtype = self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
    else:
        cond_dtype = DType.BOOL
    cond_tens = self.ser.addOutput([], cond_dtype)
    self.ser.addOperator(Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name])

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Local intermediate tensors must be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tensor in body_inputs:
        self.ser.addInputTensor(tensor)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_template, acc_template = incorrect_iter, incorrect_acc
    else:
        iter_template, acc_template = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_template.shape, iter_template.dtype)
    acc_body_out = self.ser.addIntermediate(acc_template.shape, acc_template.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name])
    for tensor in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tensor)

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks
    )

    return acc_out
4963
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
    """Work out which shapes, ranks and dtypes tests should be generated for.

    Args:
        op: operator description dict; reads ``op["rank"]`` (inclusive
            ``(min, max)`` pair) and ``op["types"]``.
        shapeFilter: list of requested shapes; falsy means ``[None]`` (any).
        rankFilter: iterable of requested ranks, or None.
        dtypeFilter: iterable of requested dtypes, or None.
        testType: ``'positive'`` or ``'negative'``.
        validator: ERROR_IF validator; called as
            ``validator(check=False, op=op)`` for negative tests.

    Returns:
        dict with ``'shapeFilter'``, ``'rankFilter'`` and ``'dtypeFilter'``.
        Returns None for a negative test without a validator, or for any
        other ``testType``.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Default behaviour is bounded by BOTH the default range and the
        # operator's range, so ranks the operator disallows are never used.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # No overlap with the default range: test the operator's
                # lowest ranks, as many as the default range holds.
                cleanRankFilter = range(rmin, rmin + len(default_test_rank_range))
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep requested operator dtypes; list entries are matched on their
        # first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == 'positive':
        return {
            'shapeFilter': shapeFilter,
            'rankFilter': cleanRankFilter,
            'dtypeFilter': cleanDtypeFilter
        }
    elif testType == 'negative':
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)['param_reqs']

        # Parameters required by the ERROR_IF test override the defaults
        if error_arguments['rank'] is not None:
            rankFilter = error_arguments['rank']
        else:
            rankFilter = cleanRankFilter

        if error_arguments['dtype'] is not None:
            dtypeFilter = error_arguments['dtype']
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments['shape'] is not None:
            shapeFilter = error_arguments['shape']
        else:
            shapeFilter = shapeFilter[:2]  # Reduce number of shapes to keep test numbers small

        return {
            'shapeFilter': shapeFilter,
            'rankFilter': rankFilter,
            'dtypeFilter': dtypeFilter
        }
    return None
5034
5035
def genOpTestList(
    self, opName, shapeFilter=None, rankFilter=None, dtypeFilter=None, testType='positive'
):
    """Generate the list of tests for operator ``opName``.

    Resets ``self.rng`` from ``self.random_seed`` before generating.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: list of requested shapes, or None for ``[None]`` (any).
        rankFilter: iterable of requested ranks, or None.
        dtypeFilter: iterable of requested dtypes, or None.
        testType: ``'positive'`` or ``'negative'`` (ERROR_IF tests).

    Returns:
        list of ``(opName, testStr, dtype, error_name, shapeList, args)`` tuples.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == 'negative' and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)['error_name']
        else:
            error_name = None

        filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
        if filterDict is None:
            # Nothing to generate for this validator; keep tests already built
            continue

        for r in filterDict['rankFilter']:
            for t in filterDict['dtypeFilter']:
                for shape in filterDict['shapeFilter']:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == 'positive':
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == 'negative':
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)

                        testList.append((opName, testStr, t, error_name, shapeList, args))

    if testType == 'positive' and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement: a test is dropped if ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5])
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
5124
Matthew Haddone86fd342021-09-07 16:12:21 +01005125
def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
    """Build one test graph for ``opName`` and serialize it as "test".

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name passed to ``self.createSerializer``.
        dtype_or_dtypeList: a single dtype (repeated for every operand, or
            for every shape for CONCAT) or an explicit list of dtypes.
        error_name: ErrorIf under test, or None.
        shapeList: operand shapes.
        testArgs: extra positional arguments for the op's build function.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        AssertionError: if shape/dtype counts don't match the operand count
            (non-CONCAT operators).
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, _ = op["build_fcn"]
    error_if_validators = op["error_if_validators"] if "error_if_validators" in op else None

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT takes a variable number of inputs: one dtype per shape
        repeat_count = len(shapeList) if op["op"] == Op.CONCAT else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat_count

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Build the random tensor operands and the test
    tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    try:
        call_args = [self, op, *tens, *testArgs]
        call_kwargs = {}
        if error_if_validators is None:
            # Without validators the quantization info is passed positionally
            if qinfo is not None:
                call_args.append(qinfo)
        else:
            call_kwargs["validator_fcns"] = error_if_validators
            call_kwargs["error_name"] = error_name
            if qinfo is not None:
                call_kwargs["qinfo"] = qinfo
        resultName = build_fcn(*call_args, **call_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName is None:
        print("Invalid ERROR_IF tests created")

    # Save the serialized test
    self.serialize("test")
5198
5199
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005200 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005201 pCount, cCount = op["operands"]
5202
5203 tens = []
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005204 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005205 # Make sure the operation does not cause value saturation - where
5206 # the number wraps due to limited number of bits to store the answer
5207 assert (
5208 pCount == 2 and cCount == 0
5209 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005210 placeholders = []
5211 add = (op["op"] == Op.ADD)
5212 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5213 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5214 if add:
5215 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
5216 else:
5217 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
5218
5219 # Work out the saturation limits
5220 max_i32 = (1 << 31)-1
5221 min_i32 = -(1 << 31)
5222 max_arr = np.full(shapeList[1], max_i32)
5223 min_arr = np.full(shapeList[1], min_i32)
5224
5225 # Find how much values exceed the maximum/minimums
5226 sat_max_arr = np.maximum(res_arr - max_arr, 0)
5227 sat_min_arr = np.minimum(res_arr - min_arr, 0)
5228
5229 if not add:
5230 # Swap saturation values and negate values as we need to perform opposite operations
5231 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
5232
5233 # Create new array of unsaturated values by clipping values as needed
5234 b_unsat_arr = b_arr
5235 if (sat_max_arr != 0).any():
5236 # Clip values that cause saturation
5237 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
5238 # Reduce axes in unsaturated tensor to match original tensor
5239 for axis, dim in enumerate(b_arr.shape):
5240 if dim != b_unsat_arr.shape[axis]:
5241 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
5242 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
5243
5244 if (sat_min_arr != 0).any():
5245 # Clip values that cause saturation
5246 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
5247 # Reduce axes in unsaturated tensor to match original tensor
5248 for axis, dim in enumerate(b_arr.shape):
5249 if dim != b_unsat_arr.shape[axis]:
5250 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
5251 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
5252
5253 placeholders.append(
5254 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5255 )
5256 placeholders.append(
5257 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
5258 )
5259
5260 tens.extend(placeholders)
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005261 elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
5262 # Limit input tensors with cond_if_binary or while_loop to stop
5263 # saturation of add/sub ops
5264 pRemain = pCount
5265 placeholders = []
5266 for idx, shape in enumerate(shapeList[:]):
5267 arr = self.getRandTensor(shapeList[idx], DType.INT16)
5268 if pRemain > 0:
5269 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
5270 pRemain -= 1
5271 else:
5272 placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
5273
5274 tens.extend(placeholders)
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005275 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
5276 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005277 assert (
5278 pCount == 2 and cCount == 0
5279 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08005280
5281 placeholders = []
5282 for idx, shape in enumerate(shapeList[:]):
5283 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07005284 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005285 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005286 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005287 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005288 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005289 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005290 elif error_name == ErrorIf.WrongInputType:
5291 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005292 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005293 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08005294 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005295 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07005296 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005297
5298 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01005299 elif op["op"] == Op.SELECT:
5300 # Set datatype of condition tensor to boolean
5301 dtypeList[0] = DType.BOOL
5302 tens.extend(
5303 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5304 )
5305 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005306 elif op["op"] == Op.INTDIV and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005307 assert (
5308 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01005309 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005310
5311 placeholders = []
5312
Matthew Haddon459443c2021-08-23 16:43:13 +01005313 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005314 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07005315 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005316 while True:
5317 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5318 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5319
5320 if (divisor_arr == 0).any():
5321 continue
5322
Kevin Cheng47315e12021-05-13 17:41:28 -07005323 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005324 continue
5325
5326 break
5327
5328 placeholders.append(
5329 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
5330 )
5331 placeholders.append(
5332 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
5333 )
5334
5335 tens.extend(placeholders)
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005336 elif op["op"] == Op.MUL and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005337 assert (
5338 pCount == 2 and cCount == 0
5339 ), "Op.MUL must have 2 placeholders, 0 consts"
5340
5341 if dtypeList[0] == DType.FLOAT:
5342 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
5343 else:
5344 placeholders = []
5345
5346 # Make sure multiply result in int32 range
5347 shift = testArgs[0]
5348 if dtypeList[0] == DType.INT8:
5349 num_bits = 8
5350 elif dtypeList[0] == DType.INT16:
5351 num_bits = 16
5352 elif dtypeList[0] == DType.INT32:
5353 num_bits = 32
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005354 elif error_name == ErrorIf.WrongInputType:
5355 num_bits = 8
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005356 else:
5357 raise Exception("OpMul: invalid input dtype")
5358
5359 for idx, shape in enumerate(shapeList[:]):
5360 low = -(2 ** (num_bits - 1))
5361 high = (2 ** (num_bits - 1)) - 1
5362
5363 a_arr = np.int32(
5364 self.rng.integers(low=low, high=high, size=shapeList[0])
5365 )
5366 b_arr = np.int32(
5367 self.rng.integers(low=low, high=high, size=shapeList[1])
5368 )
5369
5370 i = 0
5371 while True:
5372
5373 a_arr_64 = a_arr.astype(np.int64)
5374 b_arr_64 = b_arr.astype(np.int64)
5375
5376 if shift > 0:
5377 rounding = 1 << (shift - 1)
5378 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
5379 else:
5380 result_arr = a_arr_64 * b_arr_64
5381
5382 if (result_arr > -(2 ** 31)).all() and (
5383 result_arr <= ((2 ** 31) - 1)
5384 ).all():
5385 break
5386
5387 i = i + 1
5388 a_arr = a_arr // 2
5389 b_arr = b_arr // 2
5390
5391 placeholders.append(
5392 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5393 )
5394 placeholders.append(
5395 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
5396 )
5397
5398 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01005399 elif op["op"] == Op.CONCAT:
5400 count = len(shapeList) - self.args.num_const_inputs_concat
5401 if count < 1:
5402 count = 1
5403 if self.args.num_const_inputs_concat == 0:
5404 count = len(shapeList)
5405
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005406 # Ensure axis is an int
5407 testArgs[0] = int(testArgs[0])
5408
5409 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name)
5410
Matthew Haddon818ab902021-07-27 09:12:49 +01005411 tens.extend(
5412 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
5413 )
5414 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005415 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07005416 tens.extend(
5417 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5418 )
5419 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07005420
Matthew Haddon1c00b712021-10-01 15:51:03 +01005421 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07005422
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST in place.

    For each kernel size, a concrete entry named ``<prefix>_<k0>x<k1>`` (2D)
    or ``<prefix>_<k0>x<k1>x<k2>`` (3D) is added. It is a shallow copy of
    ``<prefix>_TEMPLATE`` with ``"filter"`` set to the kernel size and
    ``"template"`` set to False. After all concrete ops exist, every entry
    whose ``"template"`` field equals True is removed.

    Raises:
        KeyError: if one of the ``*_TEMPLATE`` entries is absent.
    """
    op_list = self.TOSA_OP_LIST

    def _instantiate(prefix, kernel):
        # Shallow copy of the template with the kernel filled in. The key
        # order matches dict.copy() followed by setting "filter" then
        # "template".
        op_name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        op_list[op_name] = {
            **op_list["{}_TEMPLATE".format(prefix)],
            "filter": kernel,
            "template": False,
        }

    # Dynamically create op lists for convolutions with a list of kernel sizes.
    # Kernel-major order keeps conv2d / depthwise / transpose entries grouped
    # per kernel, as the generated tests expect.
    kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            _instantiate(prefix, kernel)

    kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    for kernel in kernels_3d:
        _instantiate("conv3d", kernel)

    # Drop the templates only after expansion. The names are collected first
    # because keys must not be deleted while iterating the dict. `== True`
    # (not plain truthiness) is deliberate, to keep the original semantics.
    template_names = [
        name
        for name, entry in op_list.items()
        if entry.get("template") == True  # noqa: E712
    ]
    for name in template_names:
        del op_list[name]
5469
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints every TOSA_OP_LIST entry for its required fields:

    * ``operands`` must unpack into exactly two values (placeholders, consts);
    * ``build_fcn`` must unpack into exactly three values
      (build function, tensor generator, argument generator);
    * ``types`` and ``op`` must be present.

    A missing ``rank`` is set to ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: if a required field is missing or malformed. For the
            tuple fields, the underlying lookup/unpacking error is chained
            as ``__cause__``.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields. Unpacking, rather than only a key check, also
        # validates the tuple arity; the unpacked values are not needed here.
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _build, _tensor_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07005511
# Tensor operator list
# Each key is a test name; each value is a dict with these fields:
# 'op': the Op opcode under test (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults() fills in DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator() function, TosaTensorGen function,
#     TosaArgGen function or None)
# 'types': array of datatypes to be tested
# 'qgen': optional TosaQuantGen function (presumably produces the
#     quantization info for quantized types -- confirm)
# 'error_if_validators': optional tuple of TosaErrorValidator functions
#     (presumably one per ERROR_IF condition exercised -- confirm)
# 'invalid_test_validators': optional tuple of TosaInvalidValidator functions
#     (presumably used to filter out invalid argument combinations -- confirm)
# 'template': True marks a template entry; createDynamicOpLists() expands it
#     into concrete ops (setting a 'filter' kernel size) and then deletes it
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Convolution type list. Each entry is either a single DType or a list of
# three DTypes (presumably [input, weight, accumulator] -- confirm against
# the conv build/arg-gen functions).
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# Default 'rank' range applied by initOpListDefaults() when an op omits it.
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07005539
5540 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08005541 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08005542 "argmax": {
5543 "op": Op.ARGMAX,
5544 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005545 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005546 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5547 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005548 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
5549 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5550 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005551 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005552 "avg_pool2d": {
5553 "op": Op.AVG_POOL2D,
5554 "operands": (1, 0),
5555 "rank": (4, 4),
5556 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5557 "qgen": TosaQuantGen.qgUnary,
5558 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00005559 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005560 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5561 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5562 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5563 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005564 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005565 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005566 "conv2d_TEMPLATE": {
5567 "op": Op.CONV2D,
5568 "operands": (1, 2),
5569 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01005570 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005571 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005572 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005573 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5574 "error_if_validators": (
5575 TosaErrorValidator.evWrongInputType,
5576 TosaErrorValidator.evWrongOutputType,
5577 TosaErrorValidator.evWrongInputList,
5578 TosaErrorValidator.evWrongOutputList,
5579 TosaErrorValidator.evInputZeroPointNotZero,
5580 TosaErrorValidator.evWeightZeroPointNotZero,
5581 TosaErrorValidator.evPadSmallerZero,
5582 TosaErrorValidator.evStrideSmallerOne,
5583 TosaErrorValidator.evDilationSmallerOne,
5584 TosaErrorValidator.evWrongRank,
5585 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005586 "template": True,
5587 },
Kevin Cheng1533b852021-09-01 12:51:58 -07005588 # Templated operator. Filled in by createDynamicOpLists
5589 "conv3d_TEMPLATE": {
5590 "op": Op.CONV3D,
5591 "operands": (1, 2),
5592 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01005593 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07005594 "qgen": TosaQuantGen.qgConv,
5595 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005596 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5597 "error_if_validators": (
5598 TosaErrorValidator.evWrongInputType,
5599 TosaErrorValidator.evWrongOutputType,
5600 TosaErrorValidator.evWrongInputList,
5601 TosaErrorValidator.evWrongOutputList,
5602 TosaErrorValidator.evInputZeroPointNotZero,
5603 TosaErrorValidator.evWeightZeroPointNotZero,
5604 TosaErrorValidator.evPadSmallerZero,
5605 TosaErrorValidator.evStrideSmallerOne,
5606 TosaErrorValidator.evDilationSmallerOne,
5607 TosaErrorValidator.evWrongRank,
5608 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07005609 "template": True,
5610 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005611 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005612 "depthwise_conv2d_TEMPLATE": {
5613 "op": Op.DEPTHWISE_CONV2D,
5614 "operands": (1, 2),
5615 "filter": [1, 1],
5616 "rank": (4, 4),
5617 "build_fcn": (
5618 build_depthwise_conv2d,
5619 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01005620 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005621 ),
5622 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005623 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005624 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
5625 "error_if_validators": (
5626 TosaErrorValidator.evWrongInputType,
5627 TosaErrorValidator.evWrongOutputType,
5628 TosaErrorValidator.evWrongInputList,
5629 TosaErrorValidator.evWrongOutputList,
5630 TosaErrorValidator.evInputZeroPointNotZero,
5631 TosaErrorValidator.evWeightZeroPointNotZero,
5632 TosaErrorValidator.evPadSmallerZero,
5633 TosaErrorValidator.evStrideSmallerOne,
5634 TosaErrorValidator.evDilationSmallerOne,
5635 TosaErrorValidator.evWrongRank,
5636 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005637 "template": True,
5638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005639 "fully_connected": {
5640 "op": Op.FULLY_CONNECTED,
5641 "operands": (1, 2),
5642 "rank": (2, 2),
5643 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
5644 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005645 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005646 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
5647 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005648 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005649 "matmul": {
5650 "op": Op.MATMUL,
5651 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07005652 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08005653 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
5654 "qgen": TosaQuantGen.qgMatmul,
5655 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005656 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5657 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005658 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005659 "max_pool2d": {
5660 "op": Op.MAX_POOL2D,
5661 "operands": (1, 0),
5662 "rank": (4, 4),
5663 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5664 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00005665 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005666 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5667 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5668 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005669 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005670 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005671 "transpose_conv2d_TEMPLATE": {
5672 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07005673 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005674 "rank": (4, 4),
5675 "build_fcn": (
5676 build_transpose_conv2d,
5677 TosaTensorGen.tgTransposeConv2D,
5678 TosaArgGen.agTransposeConv2D,
5679 ),
5680 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005681 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00005682 "invalid_test_validators": (
5683 TosaInvalidValidator.ivHeightWidthInvalid,
5684 TosaInvalidValidator.ivNonPositiveOutputShape,
5685 ),
5686 "error_if_validators": (
5687 TosaErrorValidator.evWrongInputType,
5688 TosaErrorValidator.evWrongOutputType,
5689 TosaErrorValidator.evWrongInputList,
5690 TosaErrorValidator.evWrongOutputList,
5691 TosaErrorValidator.evInputZeroPointNotZero,
5692 TosaErrorValidator.evWeightZeroPointNotZero,
5693 TosaErrorValidator.evPadSmallerZero,
5694 TosaErrorValidator.evStrideSmallerOne,
5695 TosaErrorValidator.evDilationSmallerOne,
5696 TosaErrorValidator.evWrongRank,
5697 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005698 "template": True,
5699 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005700 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08005701 "clamp": {
5702 "op": Op.CLAMP,
5703 "operands": (1, 0),
5704 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
5705 "types": TYPE_NARROW_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005706 "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5707 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005708 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08005709 "sigmoid": {
5710 "op": Op.SIGMOID,
5711 "operands": (1, 0),
5712 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
5713 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005714 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5715 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005716 },
5717 "tanh": {
5718 "op": Op.TANH,
5719 "operands": (1, 0),
5720 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
5721 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005722 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5723 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005724 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005725 # Elementwise Binary Operators
5726 "add": {
5727 "op": Op.ADD,
5728 "operands": (2, 0),
5729 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5730 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005731 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005732 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005733 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005734 "arithmetic_right_shift": {
5735 "op": Op.ARITHMETIC_RIGHT_SHIFT,
5736 "operands": (2, 0),
5737 "build_fcn": (
5738 build_arithmetic_right_shift,
5739 TosaTensorGen.tgBroadcastFuzz,
5740 TosaArgGen.agArithmeticRightShift,
5741 ),
5742 "types": TYPE_INT,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005743 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5744 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005746 "bitwise_and": {
5747 "op": Op.BITWISE_AND,
5748 "operands": (2, 0),
5749 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5750 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005751 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005752 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005753 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005754 "bitwise_or": {
5755 "op": Op.BITWISE_OR,
5756 "operands": (2, 0),
5757 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5758 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005759 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005760 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005761 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005762 "bitwise_xor": {
5763 "op": Op.BITWISE_XOR,
5764 "operands": (2, 0),
5765 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5766 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005767 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005768 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005769 },
Matthew Haddon459443c2021-08-23 16:43:13 +01005770 "intdiv": {
5771 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005772 "operands": (2, 0),
5773 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5774 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005775 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005776 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005777 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005778 "logical_and": {
5779 "op": Op.LOGICAL_AND,
5780 "operands": (2, 0),
5781 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5782 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005783 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005784 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005786 "logical_left_shift": {
5787 "op": Op.LOGICAL_LEFT_SHIFT,
5788 "operands": (2, 0),
5789 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5790 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005791 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005792 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005793 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005794 "logical_right_shift": {
5795 "op": Op.LOGICAL_RIGHT_SHIFT,
5796 "operands": (2, 0),
5797 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5798 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005799 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005800 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005802 "logical_or": {
5803 "op": Op.LOGICAL_OR,
5804 "operands": (2, 0),
5805 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5806 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005807 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005808 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005809 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005810 "logical_xor": {
5811 "op": Op.LOGICAL_XOR,
5812 "operands": (2, 0),
5813 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5814 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005815 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005816 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005817 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005818 "maximum": {
5819 "op": Op.MAXIMUM,
5820 "operands": (2, 0),
5821 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5822 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005823 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005824 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005825 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005826 "minimum": {
5827 "op": Op.MINIMUM,
5828 "operands": (2, 0),
5829 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5830 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005831 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005832 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005833 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005834 "mul": {
5835 "op": Op.MUL,
5836 "operands": (2, 0),
5837 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
5838 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005839 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005840 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005841 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005842 "pow": {
5843 "op": Op.POW,
5844 "operands": (2, 0),
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005845 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08005846 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005847 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005848 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005849 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005850 "sub": {
5851 "op": Op.SUB,
5852 "operands": (2, 0),
5853 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5854 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005855 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005856 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005857 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005858 "table": {
5859 "op": Op.TABLE,
5860 # Use the automatic generation functions to create the input array
5861 # but create the table tensor in the build function, as it may be
5862 # a different type from the input
5863 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00005864 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005865 "types": [DType.INT8, DType.INT16],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005866 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5867 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005868 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005869 # Elementwise Unary operators
5870 "abs": {
5871 "op": Op.ABS,
5872 "operands": (1, 0),
5873 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5874 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005875 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5876 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005877 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005878 "bitwise_not": {
5879 "op": Op.BITWISE_NOT,
5880 "operands": (1, 0),
5881 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5882 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005883 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5884 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005885 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005886 "ceil": {
5887 "op": Op.CEIL,
5888 "operands": (1, 0),
5889 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5890 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005891 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5892 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005893 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005894 "clz": {
5895 "op": Op.CLZ,
5896 "operands": (1, 0),
5897 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5898 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005899 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5900 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005901 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005902 "exp": {
5903 "op": Op.EXP,
5904 "operands": (1, 0),
5905 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5906 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005907 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5908 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005910 "floor": {
5911 "op": Op.FLOOR,
5912 "operands": (1, 0),
5913 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5914 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005915 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5916 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005917 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005918 "log": {
5919 "op": Op.LOG,
5920 "operands": (1, 0),
5921 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5922 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005923 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5924 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005926 "logical_not": {
5927 "op": Op.LOGICAL_NOT,
5928 "operands": (1, 0),
5929 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5930 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005931 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5932 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005933 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005934 "negate": {
5935 "op": Op.NEGATE,
5936 "operands": (1, 0),
5937 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5938 "qgen": TosaQuantGen.qgUnary,
5939 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005940 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5941 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5942 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005944 "reciprocal": {
5945 "op": Op.RECIPROCAL,
5946 "operands": (1, 0),
5947 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5948 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005949 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5950 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005951 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005952 "rsqrt": {
5953 "op": Op.RSQRT,
5954 "operands": (1, 0),
5955 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5956 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005957 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5958 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005960 # Elementwise Ternary operators
5961 "select": {
5962 "op": Op.SELECT,
5963 "operands": (3, 0),
5964 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
5965 "types": TYPE_FIB,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005966 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5967 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005969 # Comparison operators
5970 "equal": {
5971 "op": Op.EQUAL,
5972 "operands": (2, 0),
5973 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5974 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005975 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5976 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005977 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005978 "greater_equal": {
5979 "op": Op.GREATER_EQUAL,
5980 "operands": (2, 0),
5981 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5982 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005983 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5984 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005985 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005986 "greater": {
5987 "op": Op.GREATER,
5988 "operands": (2, 0),
5989 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5990 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005991 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5992 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005993 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005994 # Reduction operators
5995 "reduce_all": {
5996 "op": Op.REDUCE_ALL,
5997 "operands": (1, 0),
5998 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5999 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006000 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6001 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6002 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006003 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006004 "reduce_any": {
6005 "op": Op.REDUCE_ANY,
6006 "operands": (1, 0),
6007 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6008 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006009 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6010 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6011 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006012 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006013 "reduce_max": {
6014 "op": Op.REDUCE_MAX,
6015 "operands": (1, 0),
6016 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6017 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006018 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6019 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6020 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006021 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006022 "reduce_min": {
6023 "op": Op.REDUCE_MAX,
6024 "operands": (1, 0),
6025 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6026 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006027 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6028 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6029 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006031 "reduce_product": {
6032 "op": Op.REDUCE_PRODUCT,
6033 "operands": (1, 0),
6034 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6035 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006036 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6037 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6038 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006039 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006040 "reduce_sum": {
6041 "op": Op.REDUCE_SUM,
6042 "operands": (1, 0),
6043 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6044 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01006045 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
6046 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6047 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08006048 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006049 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006050 "concat": {
6051 "op": Op.CONCAT,
6052 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01006053 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006054 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006055 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
Matthew Haddon01c359d2021-10-15 16:30:48 +01006056 TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
6057 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006058 },
6059 "pad": {
6060 "op": Op.PAD,
6061 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006062 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006063 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
6064 "qgen": TosaQuantGen.qgPad,
6065 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006066 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
6067 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006068 },
6069 "reshape": {
6070 "op": Op.RESHAPE,
6071 "operands": (1, 0),
6072 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
6073 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006074 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6075 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006076 },
6077 "reverse": {
6078 "op": Op.REVERSE,
6079 "operands": (1, 0),
6080 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6081 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006082 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType,
6083 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006084 },
6085 "slice": {
6086 "op": Op.SLICE,
6087 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006088 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006089 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
6090 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006091 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
6092 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
6093 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006094 },
6095 "tile": {
6096 "op": Op.TILE,
6097 "operands": (1, 0),
6098 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
6099 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006100 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6101 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006102 },
6103 "transpose": {
6104 "op": Op.TRANSPOSE,
6105 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01006106 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006107 "build_fcn": (
6108 build_transpose,
6109 TosaTensorGen.tgBasic,
6110 TosaArgGen.agTranspose,
6111 ),
6112 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01006113 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
6114 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006115 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006116 # Data nodes
6117 "const": {
6118 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07006119 "operands": (0, 1),
6120 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08006121 "types": TYPE_FIB,
6122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006123 "identity": {
6124 "op": Op.IDENTITY,
6125 "operands": (1, 0),
6126 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6127 "types": TYPE_FIB,
6128 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006129 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08006130 "gather": {
6131 "op": Op.GATHER,
6132 # Only specify 'values' tensor here. 'indices' is generated in op building stage
6133 "operands": (1, 0),
6134 "rank": (3, 3),
6135 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
6136 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006137 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6138 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006139 },
6140 "scatter": {
6141 "op": Op.SCATTER,
6142 # Only specify 'values_in' tensor here.
6143 #'indices' and 'input' are generated in op building stage
6144 "operands": (2, 0),
6145 "rank": (3, 3),
6146 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
6147 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006148 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6149 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006150 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006151 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08006152 "resize": {
6153 "op": Op.RESIZE,
6154 "operands": (1, 0),
6155 "rank": (4, 4),
6156 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
6157 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01006158 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
6159 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
6160 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01006161 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01006162 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
6163 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006164 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006165 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08006166 "cast": {
6167 "op": Op.CAST,
6168 "operands": (1, 0),
6169 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
6170 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01006171 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
6172 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006173 },
6174 "rescale": {
6175 "op": Op.RESCALE,
6176 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01006177 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006178 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01006179 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01006180 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
6181 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
6182 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006183 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006184 # Custom
6185 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08006186 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07006187 # Two variants of cond_if, one that generates one of two constant tensors (no
6188 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
6189 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006190 "cond_if_const": {
6191 "op": Op.COND_IF,
6192 "operands": (0, 2),
6193 "build_fcn": (
6194 build_cond_if_const,
6195 TosaTensorGen.tgBasic,
6196 TosaArgGen.agCondIf,
6197 ),
6198 "types": [DType.BOOL],
Matthew Haddon630c17c2021-10-14 15:05:41 +01006199 "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006200 },
6201 "cond_if_binary": {
6202 "op": Op.COND_IF,
6203 "operands": (2, 0),
6204 "build_fcn": (
6205 build_cond_if_binary,
6206 TosaTensorGen.tgBasic,
6207 TosaArgGen.agCondIf,
6208 ),
Les Bell6040b4d2021-10-11 12:50:31 +01006209 "types": TYPE_INT_FP,
Matthew Haddon630c17c2021-10-14 15:05:41 +01006210 "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch,
6211 TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006212 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006213 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08006214 "while_loop": {
6215 "op": Op.WHILE_LOOP,
6216 "operands": (0, 1),
6217 "build_fcn": (
6218 build_while_loop,
6219 TosaTensorGen.tgBasic,
6220 TosaArgGen.agWhileLoop,
6221 ),
6222 "types": [DType.INT32],
Matthew Haddon630c17c2021-10-14 15:05:41 +01006223 "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch,
6224 TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch,
6225 TosaErrorValidator.evCondGraphOutputNotMatchingBool)
Kevin Cheng550ccc52021-03-03 11:21:43 -08006226 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006227 }
6228
Kevin Cheng550ccc52021-03-03 11:21:43 -08006229
Eric Kunzee5e26762020-10-13 16:11:07 -07006230class OutputShaper:
6231 # Methods in this class compute the expected output shape and datatype
6232 # for common classes of operations
6233 def __init__(self):
6234 pass
6235
6236 # These methods return arguments that can be used for
6237 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of an elementwise binary op with broadcasting.

    For a valid test (``error_name is None``) each output dimension is taken
    from ``a``, except where ``a`` has size 1, in which case ``b``'s size is
    used. When an ErrorIf case is being generated no broadcasting is applied
    and the output simply mirrors ``a``'s shape.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is returned.
        rng: random generator, only used to pick a wrong output dtype.
        a, b: input tensors (their ``shape`` and ``dtype`` are read).
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    # RankMismatch tests deliberately feed inputs of different ranks.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Only broadcast for valid tests; short-circuiting also keeps ``b`` from
    # being indexed when the ranks may differ.
    broadcast = error_name is None
    shape = []
    for i, dim in enumerate(a.shape):
        if dim == 1 and broadcast:
            shape.append(b.shape[i])
        else:
            shape.append(dim)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006259
6260 @staticmethod
6261 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006262 assert len(a.shape) == len(b.shape)
6263 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07006264
6265 shape = []
6266 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006267 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07006268 shape.append(a.shape[i])
6269
Kevin Cheng550ccc52021-03-03 11:21:43 -08006270 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006271
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor of an elementwise unary op.

    The output keeps ``a``'s shape. Its dtype is ``a``'s dtype, unless a
    WrongOutputType ErrorIf case is requested, in which case a different
    dtype is chosen at random.

    Returns:
        The value returned by ``ser.addOutput`` for that shape and dtype.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006282
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor of SELECT(cond, a, b).

    The output shape follows ``cond``. For a valid test (``error_name is
    None``), a size-1 dimension of ``cond`` is broadcast to the largest
    corresponding size among ``cond``, ``a`` and ``b``. When an ErrorIf case
    is being generated, ``cond``'s shape is used unchanged.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is returned.
        rng: random generator, only used to pick a wrong output dtype.
        cond, a, b: input tensors (their ``shape``/``dtype`` are read).
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    # RankMismatch tests deliberately feed inputs of different ranks.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        # The ``error_name is None`` check short-circuits before a/b are
        # indexed, which matters when the ranks may differ.
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006304
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of an elementwise comparison op.

    The output shape broadcasts ``a`` against ``b``: a size-1 dimension of
    ``a`` takes ``b``'s size for that dimension. The output dtype is BOOL,
    unless a WrongOutputType ErrorIf case is requested, in which case a
    non-BOOL dtype is picked at random.

    With ``ErrorIf.RankMismatch`` the inputs can have different ranks, so
    ``b`` is only consulted for dimensions it actually has. A size-1
    dimension of ``a`` beyond ``b``'s rank is kept as-is instead of raising
    IndexError.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is returned.
        rng: random generator, only used to pick a wrong output dtype.
        a, b: input tensors (their ``shape``/``dtype`` are read).
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast, guarding against ``b`` having fewer dimensions than ``a``.
    b_rank = len(b.shape)
    shape = []
    for i, dim in enumerate(a.shape):
        if dim == 1 and i < b_rank:
            shape.append(b.shape[i])
        else:
            shape.append(dim)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006326
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a REDUCE_* op along ``axis``.

    For a valid test the reduced axis becomes size 1. For the axis ErrorIf
    cases (AxisSmallerZero, AxisLargerRank) the input shape is kept
    unchanged. For ShapeOfAxisNotOne the axis keeps its input size, and if
    that size is 1 it is replaced by ``rng.integers(2, 10)``.

    The output dtype matches ``a``, unless WrongOutputType is requested, in
    which case a different dtype is picked at random.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    out_shape = a.shape.copy()
    axis_errors = [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]
    if error_name in axis_errors:
        # ``axis`` may be out of range for the first two cases, so only
        # ShapeOfAxisNotOne ever touches out_shape[axis] here.
        if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    else:
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        every_dtype = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(every_dtype) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006343
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of ARGMAX along ``axis``.

    For a valid test the reduced axis is removed from ``a``'s shape and the
    dtype is INT32. Several ErrorIf cases perturb that result:

    * AxisSmallerZero / AxisLargerRank: the axis is not removed.
    * ArgmaxOutputRankMismatch: one leading dim is dropped (if more than one
      remains), or a trailing size-1 dim is appended, chosen at random.
    * ArgmaxOutputShapeMismatch: every dim is grown by ``rng.integers(1, 10)``.
    * WrongOutputType: a random dtype other than INT32 is used.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        every_dtype = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        out_dtype = rng.choice(list(set(every_dtype) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006369
@staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor of a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding`` is
    [top, bottom, left, right]; a 2-element [H, W] padding (transpose_conv2d)
    is expanded to that form first.

    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT. For a
    WrongInputType test with any other input dtype, INT32 is assumed. For
    WrongOutputType, a random usable dtype other than the expected one is used.

    Raises:
        Exception: if the input dtype is unsupported and no ErrorIf case
            explains it.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    if len(padding) == 2:
        pad_h, pad_w = padding
        padding = [pad_h, pad_h, pad_w, pad_w]

    # Spatial axes H and W sit at index 1 and 2 of both the IFM and the filter.
    spatial = []
    for axis in range(2):
        kernel = filter.shape[axis + 1]
        size = (
            ifm.shape[axis + 1]
            - kernel
            - (kernel - 1) * (dilations[axis] - 1)
            + padding[2 * axis]
            + padding[2 * axis + 1]
        ) // strides[axis] + 1
        # ERROR_IF tests can push the size to zero or below; clamp to 1.
        spatial.append(max(size, 1))
    h, w = spatial

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    accum_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accum_dtypes:
        out_dtype = accum_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Any plausible output dtype will do when the input dtype is wrong.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006421
@staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor of a 3D convolution.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. ``padding`` holds
    a (low, high) pair per spatial axis in D, H, W order, and ``strides`` /
    ``dilations`` hold one entry per spatial axis in the same order.

    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT. For a
    WrongInputType test with any other input dtype, INT32 is assumed. For
    WrongOutputType, a random usable dtype other than the expected one is used.

    Raises:
        Exception: if the input dtype is unsupported and no ErrorIf case
            explains it.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.
    """
    # D, H and W sit at index 1..3 of both the IFM and the filter.
    out_dims = []
    for axis in range(3):
        kernel = filter.shape[axis + 1]
        extent = (
            ifm.shape[axis + 1]
            - kernel
            - (kernel - 1) * (dilations[axis] - 1)
            + padding[2 * axis]
            + padding[2 * axis + 1]
        ) // strides[axis] + 1
        # ERROR_IF tests can push the size to zero or below; clamp to 1.
        out_dims.append(max(extent, 1))
    d, h, w = out_dims

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    accum_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accum_dtypes:
        out_dtype = accum_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Any plausible output dtype will do when the input dtype is wrong.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
6477
@staticmethod
def depthwiseConv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor for a DEPTHWISE_CONV2D operator.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    For each spatial axis (H, W) the output size is
    ``(in - k - (k - 1) * (dilation - 1) + pad_lo + pad_hi) // stride + 1``
    where ``pad_lo``/``pad_hi`` are ``padding[2*axis]``/``padding[2*axis + 1]``.
    Each size is clamped to at least 1.

    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.  For a
    WrongInputType ERROR_IF test with any other input dtype, INT32 is used.
    For a WrongOutputType test, a different dtype is drawn from
    ``usableDTypes`` with ``rng``.

    Raises:
        Exception: if the input dtype is unsupported and this is not a
            WrongInputType test.
    """
    spatial = []
    for axis in range(2):
        kernel = filter.shape[axis]
        size = (
            ifm.shape[axis + 1]
            - kernel
            - (kernel - 1) * (dilations[axis] - 1)
            + padding[2 * axis]
            + padding[2 * axis + 1]
        ) // strides[axis] + 1
        # Avoid illegal dimensions, which can be generated in error_if tests
        spatial.append(max(size, 1))
    h, w = spatial

    # Output channels are C * M (input channels times channel multiplier)
    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    out_dtype_by_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in out_dtype_by_input:
        out_dtype = out_dtype_by_input[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006522
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator.

    ``ifm`` is NHWC.  The output keeps N and C.  H and W are
    ``(in + pad_lo + pad_hi + stride - kernel) // stride``, using
    ``pad[0:2]``/``stride[0]``/``kernel[0]`` for H and
    ``pad[2:4]``/``stride[1]``/``kernel[1]`` for W.  If any stride is
    non-positive or any padding is negative, H and W are both set to 1.

    ERROR_IF handling:
        PoolingOutputShapeMismatch: H and W are each increased by a value
            drawn from 1..5 with ``rng``.
        WrongOutputType: a dtype other than ``ifm.dtype`` is drawn with ``rng``.
    """
    degenerate = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if degenerate:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        h = w = 1
    else:
        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        extra = [1, 2, 3, 4, 5]
        h += rng.choice(extra)
        w += rng.choice(extra)

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        outputDType = ifm.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006549
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Create the output tensor for a FULLY_CONNECTED operator.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).

    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.  For a
    WrongInputType ERROR_IF test with any other input dtype, INT32 is used.
    For a WrongOutputType test, one of INT4/INT8/INT16/INT32/INT48/FLOAT
    other than the correct dtype is drawn with ``rng``.

    Raises:
        Exception: if the input dtype is unsupported and this is not a
            WrongInputType test.  Previously a WrongOutputType test with an
            unsupported dtype failed with UnboundLocalError instead.
    """
    output_shape = [input.shape[0], filter.shape[0]]

    # Work out the correct output dtype first; the WrongOutputType case is
    # derived from it so an unexpected input dtype cannot leave it unset.
    if input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    if error_name == ErrorIf.WrongOutputType:
        # Fixed candidate order so the rng draw is reproducible
        all_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(dt for dt in all_types if dt != out_dtype)
        out_dtype = rng.choice(a=incorrect_types)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006579
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a MATMUL operator.

    Shapes: ``a`` is (N, H, C), ``b`` is (N, C, W), output is (N, H, W).

    The output dtype is derived from ``a.dtype``: INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT.  For a WrongInputType ERROR_IF test with
    any other input dtype, INT32 is used.  For a WrongOutputType test, one of
    INT4/INT8/INT16/INT32/INT48/FLOAT other than the correct dtype is drawn
    with ``rng``.

    Raises:
        Exception: if ``a.dtype`` is unsupported and this is not a
            WrongInputType test.  Previously a WrongOutputType test with an
            unsupported dtype failed with UnboundLocalError instead.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    # Work out the correct output dtype first; the WrongOutputType case is
    # derived from it so an unexpected input dtype cannot leave it unset.
    if a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    if error_name == ErrorIf.WrongOutputType:
        # Fixed candidate order so the rng draw is reproducible
        all_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(dt for dt in all_types if dt != out_dtype)
        out_dtype = rng.choice(a=incorrect_types)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006609
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Create the output tensor for a CONCAT operator.

    The output shape is the first input's shape with the sizes of all inputs
    along ``axis`` summed.  For the ConcatInputRankMismatch, AxisLargerRank
    and AxisSmallerZero ERROR_IF tests, the first input's shape is used
    unchanged.

    ERROR_IF handling:
        ConcatShapeSumMismatch: ``rng.integers(5, 10)`` is added along ``axis``.
        WrongOutputType: a dtype other than the first input's is drawn with ``rng``.
    Otherwise the output dtype is the first input's dtype.
    """
    first = a[0]
    others = a[1:]

    output_shape = first.shape.copy()
    cannot_sum = error_name in (
        # unable to concat tensors of different ranks
        ErrorIf.ConcatInputRankMismatch,
        # unable to concat tensors along an invalid axis
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        outputDType = first.dtype
    else:
        candidates = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
        outputDType = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006637
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for a PAD operator.

    Dimension ``i`` of the output is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF handling:
        PadSmallerZero: any dimension below 1 is raised to 1 so the shape
            stays legal.
        WrongOutputType: a dtype other than ``a.dtype`` is drawn with ``rng``.
    """
    output_shape = [padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)]

    # Fix negative output shape if error_if test causes it
    if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
        output_shape = [max(dim, 1) for dim in output_shape]

    if error_name != ErrorIf.WrongOutputType:
        outputDType = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006658
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for a RESHAPE operator.

    The output shape is a copy of ``shape``.  Any ``-1`` entry is replaced by
    (number of elements in ``a``) // (product of the other entries of ``shape``).

    ERROR_IF handling:
        TensorSizeInputOutputMismatch: every dimension gets
            ``rng.integers(1, 10)`` added.
        WrongOutputType: a dtype other than ``a.dtype`` is drawn with ``rng``.
    """
    output_shape = shape.copy()

    in_elements = 1
    for dim in a.shape:
        in_elements *= dim

    # Product of the explicitly given output dims (the -1 wildcard is skipped)
    known_elements = 1
    for dim in output_shape:
        if dim != -1:
            known_elements *= dim

    for idx, dim in enumerate(output_shape):
        if dim == -1:
            output_shape[idx] = in_elements // known_elements

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(output_shape)):
            output_shape[idx] += rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        outputDType = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006690
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the output tensor for a SLICE operator.

    The output shape is a copy of ``size``; ``start`` is not used here.

    ERROR_IF handling:
        WrongOutputType: a dtype other than ``a.dtype`` is drawn with ``rng``.
        SizeOutputShapeMismatch: each extent is perturbed with ``rng``.
            Extents <= 2 get +1 or +2; larger extents get one of -2, -1, +1, +2.
    """
    if error_name != ErrorIf.WrongOutputType:
        outputDType = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([a.dtype])))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for index, extent in enumerate(output_shape):
            # Small extents are only ever increased
            offsets = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[index] = extent + rng.choice(offsets)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006712
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for a TILE operator.

    Dimension ``i`` of the output is ``a.shape[i] * multiples[i]``.
    ``multiples`` is asserted to have the same length as ``a.shape``.
    For WrongOutputType, a dtype other than ``a.dtype`` is drawn with ``rng``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for axis, dim in enumerate(a.shape):
        output_shape[axis] = dim * multiples[axis]

    if error_name != ErrorIf.WrongOutputType:
        outputDType = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006730
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE operator.

    Dimension ``i`` of the output is ``a.shape[perms[i]]``.  For the
    IndexOutsideBounds ERROR_IF test, every dimension is ``a.shape[0]`` and
    ``perms`` is not used for indexing.  For WrongOutputType, a dtype other
    than ``a.dtype`` is drawn with ``rng``.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    # NOTE(review): perms is presumably deliberately invalid in the
    # IndexOutsideBounds case, hence it is not indexed with there.
    ignore_perms = error_name == ErrorIf.IndexOutsideBounds
    for axis in range(len(output_shape)):
        source_axis = 0 if ignore_perms else perms[axis]
        output_shape[axis] = a.shape[source_axis]

    if error_name != ErrorIf.WrongOutputType:
        outputDType = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006752
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER operator.

    ``values`` must be rank 3 and ``indices`` rank 2, with matching first
    dimensions (all asserted).  The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.
    The output dtype is ``values.dtype``; for WrongOutputType a different
    dtype is drawn with ``rng``.
    """
    assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        outputDType = values.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        outputDType = rng.choice(list(set(candidates) - set([values.dtype])))

    return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08006769
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER operator.

    ``values_in`` and ``input`` must be rank 3 and ``indices`` rank 2.  The
    following are asserted:
      - ``values_in`` and ``indices`` share dim 0 (N);
      - ``input`` and ``indices`` share dim 1 (W);
      - ``values_in`` and ``input`` share dim 2 (C).

    The output has the shape and dtype of ``values_in``.  For
    WrongOutputType, a different dtype is drawn with ``rng``.
    """
    assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy so the output tensor does not share (alias) the input's shape list,
    # matching the other shape functions which all copy their input shape.
    output_shape = list(values_in.shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07006789
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE operator.

    The output has the same shape as ``input``.  Its dtype is INT32 for an
    INT16 input and INT8 otherwise.  Unless this is a WrongInputType
    ERROR_IF test, the input dtype is asserted to be INT8 or INT16.  For
    WrongOutputType, a dtype other than the correct one is drawn with ``rng``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    correct_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        output_dtype = rng.choice([dt for dt in all_dtypes if dt != correct_dtype])
    else:
        output_dtype = correct_dtype

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006801
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE operator.

    Normally the output shape is
    ``[input.shape[0], output_dims[0], output_dims[1], input.shape[3]]``
    with dtype ``output_dtype``.  The resize parameters (mode, stride,
    offset, shift, *_fp, input_dtype) are not used here.

    ERROR_IF handling:
        WrongRank: shape is
            ``[input.shape[0], output_dims[0], output_dims[0], input.shape[0]]``.
        BatchMismatch: ``rng.integers(1, 10)`` is added to the batch dimension.
        ChannelMismatch: ``rng.integers(1, 10)`` is added to the channel dimension.
    """
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): output_dims[0] is used twice and input.shape[0] fills
        # the last slot; possibly deliberate so a wrong-rank input is never
        # indexed beyond dim 0 — confirm.
        shape = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
        return serializer.addOutput(shape, output_dtype)

    batch = input.shape[0]
    if error_name == ErrorIf.BatchMismatch:
        batch = batch + rng.integers(1, 10)
    channels = input.shape[3]
    if error_name == ErrorIf.ChannelMismatch:
        channels = channels + rng.integers(1, 10)

    shape = [batch, output_dims[0], output_dims[1], channels]
    return serializer.addOutput(shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006829
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type-conversion operator.

    The output has the same shape as ``val`` and dtype ``out_dtype``.
    ``rng`` and ``error_name`` are accepted to match the other shape
    functions' signatures but are not used.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07006833
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operator.

    The output shape is taken as given in ``output_shape``; only the dtype is
    derived here: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.  For a
    WrongInputType ERROR_IF test with any other input dtype, INT32 is used.
    For a WrongOutputType test, a different dtype is drawn from
    ``usableDTypes`` with ``rng``.

    Raises:
        Exception: if the input dtype is unsupported and this is not a
            WrongInputType test.
    """
    if ifm.dtype not in (DType.INT8, DType.INT16, DType.FLOAT):
        if error_name != ErrorIf.WrongInputType:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif ifm.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif ifm.dtype == DType.INT16:
        out_dtype = DType.INT48
    else:
        out_dtype = DType.FLOAT

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(output_shape, out_dtype)