blob: db443281d10fe9313f63ae17b8d531b3a4f47a9c [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
Eric Kunzee5e26762020-10-13 16:11:07 -070017import numpy as np
18import argparse
19import sys
20import re
21import os
22import subprocess
23import shlex
24import json
25import glob
26import math
27import queue
28import threading
29import traceback
30import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010031import itertools
Matthew Haddon630c17c2021-10-14 15:05:41 +010032from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# The flatc-generated DType, Op and ResizeMode types are plain classes of
# constants rather than real Python enums. Instantiate each once so the rest
# of this file can refer to values such as DType.INT8 directly.
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
Matthew Haddon630c17c2021-10-14 15:05:41 +010052
def product(shape):
    """Return the product of every dimension in ``shape``.

    An empty shape yields 1, the multiplicative identity.
    """
    result = 1
    for dim in shape:
        result *= dim
    return result
58
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Operators select one of these helpers with the 'qgen' entry of their
    operator definition.
    """

    def __init__(self):
        # Stateless: every helper in this class is a static method.
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point for ``dtype``.

        INT8 and UINT8 get a random zero point from ``testGen.randInt``.
        Every other dtype gets 0, unless ``error_name`` is one of the
        zero-point ERROR_IF cases. In that case a random non-zero value in
        [-128, 128) is returned instead.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)

        zero_point_errors = (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        )
        if error_name not in zero_point_errors:
            return 0

        zero_point = testGen.randInt(-128, 128)
        # A zero value would not exercise the error, so substitute 1.
        return zero_point if zero_point != 0 else 1

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return UnaryQuantInfo(input_zp, output_zp) for ``dtype``.

        Only the zero point selected by ``error_name`` (input or output) is
        generated in error mode; the other is generated normally.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        input_error = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        output_error = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        # Input zero point is drawn before the output one (RNG order matters).
        input_zp = TosaQuantGen.getQinfo(testGen, dtype, input_error)
        output_zp = TosaQuantGen.getQinfo(testGen, dtype, output_error)
        qinfo.UnaryQuantInfo(input_zp, output_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return ConvQuantInfo(input_zp, weights_zp).

        ``dtype_or_dtypeList`` is either a single dtype, used for input,
        weights and accumulator alike, or a list ordered
        [input, weights, accumulator].
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        else:
            dtypeList = [dtype_or_dtypeList] * 3
        input_dtype, weights_dtype = dtypeList[0], dtypeList[1]

        input_error = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        weights_error = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        input_zp = TosaQuantGen.getQinfo(testGen, input_dtype, input_error)
        weights_zp = TosaQuantGen.getQinfo(testGen, weights_dtype, weights_error)

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return MatMulQuantInfo with both zero points drawn for ``dtype``.

        ErrorIf.InputZeroPointNotZero applies to both zero points.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        zp_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype, zp_error),
            TosaQuantGen.getQinfo(testGen, dtype, zp_error),
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Return PadQuantInfo holding a single input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        zp_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, zp_error))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32. ``scale32``
        selects a 31-bit multiplier; otherwise a 15-bit one is produced.
        The shift is adjusted into the range [2, 62]. The sign of
        ``scaleFp`` is discarded, so the multiplier is always non-negative.
        """
        scaleBits = 31 if scale32 else 15
        limit = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        # frexp keeps the sign of its argument on the mantissa; drop it.
        mantissa = abs(mantissa)

        multiplier = round(mantissa * limit)
        assert multiplier <= limit

        if multiplier == limit:
            # The mantissa rounded up to 1.0: halve it and bump the exponent.
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= limit
        assert 2 <= shift <= 62

        return multiplier, shift
182
183
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately
    for each test.

    Most generators take ``(testGen, op, rank, error_name=None)``, where ``op`` is
    the operator definition dict and ``op["operands"]`` is a
    (placeholder_count, const_count) pair.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one copy of a random shape per operand.

        For ErrorIf.RankMismatch the shape is re-drawn with a different rank
        while the list is being built, so the operands do not all share one
        rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return one copy of a random rank-4 (NHWC) shape per operand.

        The rank-4 requirement is relaxed for ErrorIf.WrongRank tests. The
        batch dimension is limited by ``testGen.args.max_batch_size`` when set.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for SCATTER.

        ``input_shape`` shares dimensions 0 and 2 with ``values_in_shape``
        but has a random middle dimension W.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return per-operand copies of a random shape, one of them fuzzed.

        Normally one random dimension of one operand is set to 1 (broadcast).
        For ErrorIf.DimensionMismatch that dimension is grown by one instead.
        For ErrorIf.RankMismatch that operand gets a different rank.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank (rank + 1 dims), or drop it
                        # together with the last original dim (rank - 1 dims)
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for CONV2D.

        ifm is NHWC, filter is OHWI with H/W taken from ``op["filter"]``,
        and bias is [O].
        """
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for CONV3D.

        ifm is NDHWC, filter is ODHWI with D/H/W taken from ``op["filter"]``,
        and bias is [O].
        """
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for TRANSPOSE_CONV2D.

        The layout is the same as tgConv2D: NHWC, OHWI and [O].
        """
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for DEPTHWISE_CONV2D.

        ifm is NHWC, filter is [KH, KW, C, M] and bias is [C * M]. The
        multiplier M is kept below roughly a quarter of the maximum tensor
        dimension.
        """
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C.
        # Clamp the modulus to at least 1: a maximum shape below 4 would
        # otherwise make it 0 and raise ZeroDivisionError.
        m_limit = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeShape(1)[0] % m_limit) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter, bias] shapes for FULLY_CONNECTED.

        input is [N, IC], filter is [OC, IC] and bias is [OC], with OC random.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests.
        # (Previously this read an unbound local `shape`, raising
        # UnboundLocalError for every ERROR_IF test.)
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] for MATMUL.

        b shares a's batch (dim 0) and inner (dim 2) dimensions and has a
        random output width b_oc.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests.
        # (Previously this read an unbound local `shape`, raising
        # UnboundLocalError for every ERROR_IF test.)
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return copies of one random shape, one per concat input.

        The input count is the operand count plus ``testGen.randInt(0, 4)``
        extra inputs. For ErrorIf.ConcatInputRankMismatch every input after
        the first has one dimension removed or appended.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the concat shape along ``axis`` so const inputs stay small.

        Outside the error cases, the result starts with the unmodified first
        shape. The following shapes split its ``axis`` length by repeated
        halving, and together they add up to the original length.
        ``shapeList`` is returned as-is when it has only two entries or the
        axis is too short to split. For ErrorIf.ConcatInputDimMismatch a
        non-axis dimension is enlarged instead.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
533
534
Eric Kunzee5e26762020-10-13 16:11:07 -0700535class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800536 """Argument generators create exhaustive or random lists of attributes for operators that take
537 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
538 tuples where the descriptive_name is appended to the test name and the arglist is expanded
539 as arguments to the operator build function."""
540
Eric Kunzee5e26762020-10-13 16:11:07 -0700541 def __init__(self):
542 pass
543
544 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100545 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800546 """A trivial argument generator for operators that don't take any
547 non-tensor arguments"""
548 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700549
550 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100551 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800552 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700553 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700554 shape = shapeList[0]
555
Matthew Haddond6ce7252021-09-29 15:35:44 +0100556 if error_name == ErrorIf.AxisSmallerZero:
557 small_axis = testGen.rng.integers(-5, 0)
558 axes.append(("axis{}".format(small_axis), [small_axis]))
559 elif error_name == ErrorIf.AxisLargerRank:
560 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
561 axes.append(("axis{}".format(large_axis), [large_axis]))
562 else:
563 for a in range(0, len(shape)):
564 axes.append(("axis{}".format(a), [a]))
565
Eric Kunzee5e26762020-10-13 16:11:07 -0700566 return axes
567
568 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100569 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700570 arg_list = []
571
572 ifm_shape = shapeList[0]
573 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100574 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
575 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700576
Les Bell7aa69f42021-09-20 10:44:07 +0100577 # Check the rank
578 rank = 5 if opName.startswith("conv3d") else 4
579 assert len(ifm_shape) == rank
580 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700581
Les Bell7aa69f42021-09-20 10:44:07 +0100582 # kernel rank omits batch and channels
583 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700584
Les Bell7aa69f42021-09-20 10:44:07 +0100585 # Generate comprehensive argument lists
586 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
587 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
588 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
589 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
590 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
591 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700592
Les Bell7aa69f42021-09-20 10:44:07 +0100593 # add some oversize argument values
594 if max(ifm_shape) < 64:
595 bigPadding = 9
596 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
597 bigStride = 8
598 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
599 bigDilation = 7
600 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100601
602 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100603 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
604 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
605 if sparsity < 13:
606 sparsity = 1
607 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
608 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100609 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100610 for s in sorted(list(strides)):
611 for p in sorted(list(paddings)):
612 for d in sorted(list(dilations)):
613 if (n % sparsity == 0
614 # padding must not exceed the kernel size ?
615 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
616 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
617 # the padded shape must exceed the kernel size
618 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
619 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
620 # the padded shape must exceed the dilation
621 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
622 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
623 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100624 arg_list.append(
625 (
626 "st{}_pad{}_dilat{}".format(
627 "".join([str(x) for x in s]),
628 "".join([str(x) for x in p]),
629 "".join([str(x) for x in d]),
630 ),
631 [s, p, d],
632 )
633 )
634 n += 1
635
Kevin Cheng1533b852021-09-01 12:51:58 -0700636 return arg_list
637
638 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100639 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700640 arg_list = []
641
642 ifm_shape = shapeList[0]
643 filter_shape = shapeList[1]
644
645 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800646 assert len(ifm_shape) == 4
647 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700648
Les Bell7aa69f42021-09-20 10:44:07 +0100649 # Generate comprehensive argument lists
650 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
651 paddings = {x for x in itertools.product(*([p_vals] * 2))}
652 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
653 strides = {x for x in itertools.product(*([s_vals] * 2))}
654 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
655 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700656
Les Bell7aa69f42021-09-20 10:44:07 +0100657 # add some oversize argument values
658 if max(ifm_shape) < 64:
659 bigPadding = 9
660 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
661 bigStride = 8
662 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
663 bigDilation = 7
664 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700665
Les Bell7aa69f42021-09-20 10:44:07 +0100666 # There are too many parameter combinations, so generate them sparsely
667 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
668 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
669 if sparsity < 13:
670 sparsity = 1
671 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
672 sparsity += 1
673 n = 0
674 for s in sorted(list(strides)):
675 for p in sorted(list(paddings)):
676 for d in sorted(list(dilations)):
677 if n % sparsity == 0:
678 # Determine the output shape
679 oh = (
680 ifm_shape[1]
681 - filter_shape[1]
682 - (filter_shape[1] - 1) * (d[0] - 1)
683 + 2 * p[0]
684 ) // s[0] + 1
685 ow = (
686 ifm_shape[2]
687 - filter_shape[2]
688 - (filter_shape[2] - 1) * (d[1] - 1)
689 + 2 * p[1]
690 ) // s[1] + 1
691 os = [ifm_shape[0], oh, ow, filter_shape[0]]
692 arg_list.append(
693 (
694 "st{}_pad{}_dilat{}_os{}".format(
695 "".join([str(x) for x in s]),
696 "".join([str(x) for x in p]),
697 "".join([str(x) for x in d]),
698 "x".join([str(x) for x in os]),
699 ),
700 [s, p, d, os],
701 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800702 )
Les Bell7aa69f42021-09-20 10:44:07 +0100703 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700704
705 return arg_list
706
@staticmethod
def agPad(testGen, opName, shapeList, dtype, error_name=None):
    """Enumerate PAD arguments for every combination of per-side paddings.

    Each side of each dimension of ``shapeList[0]`` takes every value in
    [pad_min, pad_max] (or [-2, -1] when generating PadSmallerZero tests).
    A single pad constant is drawn per call with
    ``testGen.getRandNumberDType(dtype)``: it becomes the integer constant for
    BOOL/INT8/INT16/INT32 and the floating-point constant for FLOAT, with the
    other constant left at 0.  Any other dtype yields no arguments.

    Returns a list of (name, [paddings_array, pad_const_int, pad_const_fp]).
    """
    rank = len(shapeList[0])

    # The test name is the concatenation of every pad digit, so values above
    # 9 would need a more distinctive name format.
    if error_name == ErrorIf.PadSmallerZero:
        pad_values = list(range(-2, 0))
    else:
        pad_min, pad_max = 0, 1
        pad_values = list(range(pad_min, pad_max + 1))

    if dtype in (DType.BOOL, DType.INT8, DType.INT16, DType.INT32):
        pad_const_int = testGen.getRandNumberDType(dtype)
        pad_const_fp = 0
    elif dtype == DType.FLOAT:
        pad_const_int = 0
        pad_const_fp = testGen.getRandNumberDType(dtype)
    else:
        return []

    # (before, after) pairs for a single axis, then one pair per axis
    side_pairs = list(itertools.product(pad_values, repeat=2))
    arg_list = []
    for paddings in itertools.product(side_pairs, repeat=rank):
        name = "pad" + "".join(f"{before}{after}" for before, after in paddings)
        arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
    return arg_list
739
@staticmethod
def agPooling(testGen, opName, shapeList, dtype, error_name=None):
    """Generate a sparse selection of (stride, pad, kernel) pooling arguments.

    Candidate values come from the ``max_pooling_padding``,
    ``max_pooling_stride`` and ``max_pooling_kernel`` limits in
    ``testGen.args`` plus a few oversize values.  Combinations are walked in
    sorted order and only every ``sparsity``-th one is kept.

    For the pooling ERROR_IF tests listed below, every combination is first
    passed to TosaErrorIfArgGen.eiPoolingErrorIf (so random numbers are
    consumed even for combinations the sparsity filter later drops).  A
    result containing None is discarded.  Otherwise only legal combinations
    are kept.

    Returns a list of (name, [stride, pad, kernel]) tuples.
    """
    arg_list = []

    shape = shapeList[0]
    # WrongRank tests deliberately supply a non-4D input
    if error_name != ErrorIf.WrongRank:
        assert len(shape) == 4

    # Generate comprehensive argument lists
    # pad is a 4-tuple, stride and kernel are 2-tuples
    p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
    paddings = {x for x in itertools.product(*([p_vals] * 4))}
    s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
    strides = {x for x in itertools.product(*([s_vals] * 2))}
    k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
    kernels = {x for x in itertools.product(*([k_vals] * 2))}

    # add some oversize argument values
    bigStride = 7
    strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
    bigKernel = 6
    kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
    if max(shape) < 64:
        # padding must be less than the kernel size
        bigPadding = bigKernel - 1
        paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

    # There are too many parameter combinations, so generate them sparsely
    # (aim for roughly 500 kept combinations at most)
    sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
    n = 0
    # Sets are sorted so the iteration order (and hence n) is deterministic
    for s in sorted(list(strides)):
        for p in sorted(list(paddings)):
            for k in sorted(list(kernels)):
                if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                    # Corrupt one argument; None values mean "unusable combination"
                    sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                    if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in sNew]),
                                    "".join([str(x) for x in kNew]),
                                    "".join([str(x) for x in pNew]),
                                ),
                                [sNew, pNew, kNew],
                            )
                        )
                elif (n % sparsity == 0
                    # padding must be strictly smaller than the kernel size
                    # (pad[0:2] against kernel[0], pad[2:4] against kernel[1])
                    and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                    # the padded shape must exceed the kernel size
                    and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                ):
                    arg_list.append(
                        (
                            "st{}_kern{}_pad{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in k]),
                                "".join([str(x) for x in p]),
                            ),
                            [s, p, k],
                        )
                    )
                # n counts every visited combination, kept or not
                n += 1

    return arg_list
804
@staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate CAST output types for ``inDtype``.

    For WrongOutputType tests the candidate list comes from
    TosaErrorIfArgGen.eiCastErrorIf.  Otherwise the legal casts for the input
    type are used.  An unsupported input type falls back to a generic list
    when generating WrongInputType tests, and raises Exception otherwise.

    Returns one (name, [out_dtype]) tuple per candidate output type.
    """
    if error_name == ErrorIf.WrongOutputType:
        out_types = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
    else:
        # Checked in order with ==, exactly like an if/elif chain
        legal_casts = (
            (DType.INT8, [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]),
            (DType.INT16, [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]),
            (DType.INT32, [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]),
            (DType.BOOL, [DType.INT8, DType.INT16, DType.INT32]),
            (DType.FLOAT, [DType.INT8, DType.INT16, DType.INT32]),
        )
        for source_type, targets in legal_casts:
            if inDtype == source_type:
                out_types = targets
                break
        else:
            if error_name == ErrorIf.WrongInputType:
                # Pick some potentially correct output type for incorrect input type
                out_types = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
            else:
                raise Exception("Unexpected input dtype: {}".format(inDtype))

    return [("out{}".format(DTypeNames[out_type]), [out_type]) for out_type in out_types]
832
@staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate RESCALE argument combinations for ``inDtype``.

    Each entry is (name, [out_dtype, scale32, double_round, per_channel]).
    Output types and flag combinations that would trip an ERROR_IF check are
    filtered out, unless ``error_name`` requests that particular error.
    """
    wants_wrong_output = error_name == ErrorIf.WrongOutputType

    def output_type_ok(out_dtype):
        # Zero-point tests skip the 8-bit output types
        if error_name == ErrorIf.OutputZeroPointNotZero and out_dtype in (DType.UINT8, DType.INT8):
            return False
        if wants_wrong_output:
            return TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, out_dtype)
        # The only output dtype for UINT8 is INT8
        if inDtype == DType.UINT8 and out_dtype != DType.INT8:
            return False
        # The only input dtype for UINT8 output is INT8
        if inDtype != DType.INT8 and out_dtype == DType.UINT8:
            return False
        return True

    def flags_ok(scale32, double_round):
        if error_name == ErrorIf.ScaleTrue and not scale32:
            return False
        if error_name == ErrorIf.ScaleNotTrue and (scale32 or not double_round):
            return False
        # Illegal condition. INT48 input must use scale32=False
        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
            return False
        # Illegal condition. ERROR_IF(!scale32 && double_round)
        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
            return False
        return True

    candidate_types = [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]
    arg_list = []
    for out_dtype in filter(output_type_ok, candidate_types):
        # Same order as nested loops: scale32 outermost, per_channel innermost
        for scale32, double_round, per_channel in itertools.product([False, True], repeat=3):
            if not flags_ok(scale32, double_round):
                continue
            arg_list.append(
                (
                    "out{}_sc{}_dr{}_pc{}".format(
                        DTypeNames[out_dtype],
                        int(scale32),
                        int(double_round),
                        int(per_channel),
                    ),
                    [out_dtype, scale32, double_round, per_channel],
                )
            )
    return arg_list
880
@staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Generate MUL shift arguments.

    For INT32 inputs, ``testGen.args.num_rand_permutations`` entries are
    produced, each with a shift drawn by ``testGen.randInt(0, 32)``.  Every
    other dtype gets the single entry ("perm0_shift0", [0]).

    Returns a list of (name, [shift]) tuples.
    """
    arg_list = []

    # Compare DType values with ==, not `is`: they are plain values rather
    # than enum singletons, so identity only matched by accident of object
    # caching.
    if dtype == DType.INT32:
        for p in range(testGen.args.num_rand_permutations):
            shift = testGen.randInt(0, 32)
            arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
    else:
        arg_list.append(("perm0_shift0", [0]))

    return arg_list
895
@staticmethod
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
    """Return one ARITHMETIC_RIGHT_SHIFT argument per rounding flag.

    The flag is True first, then False.  The entries are named "roundTrue"
    and "roundFalse".
    """
    return [("round{}".format(rounding), [rounding]) for rounding in (True, False)]
904
# Helper function for reshape. Gets the small factors of a larger number.
@staticmethod
def getFactors(val, start=1):
    """Return the divisors ``d`` of ``val`` with ``start <= d`` and ``d * d <= val``.

    Only the "small half" of each factor pair (up to the integer square root)
    is returned, in ascending order.  The bound is computed with exact
    integer arithmetic instead of a float square root, so large values are
    not affected by rounding.  With the default ``start`` a non-positive
    ``val`` yields an empty list.
    """
    factors = []
    candidate = start
    while candidate * candidate <= val:
        if val % candidate == 0:
            factors.append(candidate)
        candidate += 1
    return factors
915
@staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random RESHAPE target shapes with the same element count as the input.

    For each of ``testGen.args.num_rand_permutations`` attempts, a target
    rank is drawn with ``testGen.randInt(1, 7)``.  Ranks that need more
    factors than ``getFactors`` found for the element count are skipped.
    The new shape is built from randomly chosen factors, and one dimension
    is sometimes replaced by -1.  A shape that duplicates one already
    generated is regenerated, up to 100 times, and dropped if still a
    duplicate.

    Returns a list of (name, [new_shape]) tuples.
    """
    arg_list = []

    origShape = shapeList[0]

    totalElements = 1
    for s in origShape:
        totalElements *= s

    # This code is NOT fast. Fortunately, the numbers are fairly small.
    factors = TosaArgGen.getFactors(totalElements)

    for p in range(testGen.args.num_rand_permutations):
        newRank = testGen.randInt(1, 7)
        if len(factors) < newRank:
            continue

        found = True
        # escape_counter breaks while loop if it continues on for too long
        escape_counter = 0
        while found:
            newShape = []
            # Generate newShape ensuring it isn't a duplicate
            remainingElements = totalElements
            shuffledFactors = testGen.rng.permutation(factors)
            for i in range(1, newRank):
                # pick rank-1 factors
                newShape.append(shuffledFactors[0])
                remainingElements = remainingElements // shuffledFactors[0]
                # Re-factor what is left so the next pick still divides it
                shuffledFactors = testGen.rng.permutation(
                    TosaArgGen.getFactors(remainingElements)
                )
            # The last dimension absorbs all remaining elements
            newShape.append(remainingElements)

            # Toss in a -1 sometimes
            # (roughly a 1 in 4 chance, depending on randInt's bounds)
            minusOne = testGen.randInt(0, newRank * 4)
            if minusOne < newRank:
                newShape[minusOne] = -1

            # Check for duplicates
            found = False
            for name, other_shape in arg_list:
                if other_shape[0] == newShape:
                    found = True
                    break

            escape_counter += 1
            if escape_counter >= 100:
                # Give up; found is still True so nothing is appended
                break

        if not found:
            arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

    return arg_list
971
@staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
    """Generate TRANSPOSE permutations for the first input shape.

    Without an error, all permutations of the axes are candidates.  For
    IndexOutsideBounds, the candidates are permutations of indices that are
    all >= rank or all negative.  For IndexUsedTwice, one index is copied
    over its neighbour before permuting.  At most
    ``testGen.args.num_rand_permutations`` candidates are returned, taken
    from a random shuffle.

    Returns a list of (name, [permutation_list]) tuples.
    """
    rank = len(shapeList[0])

    if error_name == ErrorIf.IndexOutsideBounds:
        candidates = list(itertools.permutations(range(rank + 1, 2 * rank + 1)))
        candidates += list(itertools.permutations(range(-rank, 0)))
    elif error_name == ErrorIf.IndexUsedTwice:
        # Create list with a duplicated index
        axes = list(range(rank))
        dup_pos = testGen.rng.choice(range(len(axes)))
        # NOTE(review): for rank 1 this writes the single entry onto itself,
        # so no duplicate index is actually produced.
        axes[(dup_pos + 1) % len(axes)] = axes[dup_pos]
        candidates = list(itertools.permutations(axes))
    else:
        candidates = list(itertools.permutations(range(rank)))

    # Limit to possible permutations from shape dimension or argument setting
    count = min(len(candidates), testGen.args.num_rand_permutations)
    shuffled = testGen.rng.permutation(candidates)
    return [("perm{}".format(idx), [shuffled[idx].tolist()]) for idx in range(count)]
1008
@staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random SLICE (start, size) pairs for the first input shape.

    Dimensions of length 1 or less always use start=0, size=1.  Longer
    dimensions get a random start and a random size from ``testGen.randInt``.
    An attempt in which any random size is 0 is dropped.  Surviving pairs go
    through TosaErrorIfArgGen.eiSliceErrorIf, which may corrupt them for
    ERROR_IF tests.

    Returns a list of (name, [start, size]) tuples.
    """
    ifm_shape = shapeList[0]
    results = []

    for perm in range(testGen.args.num_rand_permutations):
        begin = []
        extent = []
        has_empty_dim = False

        for dim_len in ifm_shape:
            if dim_len > 1:
                first = testGen.randInt(0, dim_len)
                length = testGen.randInt(0, dim_len - first)
                begin.append(first)
                extent.append(length)
                if length == 0:
                    # Invalid slice size; keep drawing so RNG use is unchanged
                    has_empty_dim = True
            else:
                begin.append(0)
                extent.append(1)

        if has_empty_dim:
            continue

        # If ERROR_IF test required then incorrect start, size will be returned
        begin, extent = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, begin, extent)
        results.append(("perm{}".format(perm), [begin, extent]))

    return results
1039
@staticmethod
def agTile(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random TILE multiples for the first input shape.

    Multiples are kept small so the tiled tensor does not become enormous.
    A dimension longer than 1000 uses multiple 1.  The other dimensions of a
    shape containing such a dimension use 2.  Otherwise each multiple is
    drawn with ``testGen.randInt(1, 4)``.

    Returns a list of (name, [multiples]) tuples.
    """
    ifm_shape = shapeList[0]

    def pick_multiple(dim_len):
        if dim_len > 1000:
            # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
            return 1
        if max(ifm_shape) > 1000:
            return 2
        return testGen.randInt(1, 4)

    return [
        ("perm{}".format(perm), [[pick_multiple(dim_len) for dim_len in ifm_shape]])
        for perm in range(testGen.args.num_rand_permutations)
    ]
1064
@staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Generate RESIZE arguments for both NEAREST and BILINEAR modes.

    For each mode, the legal output dtype is chosen from the (mode, input
    dtype) pair; FLOAT input maps to FLOAT output.  For WrongInputType tests
    an INT8/INT16/INT32 output list is used, and other combinations are
    skipped.  For every output dtype, ``testGen.args.num_rand_permutations``
    random output sizes are generated.  Each size is grown until the
    input/output ratio is below 16, and stride/offset are derived from it:
    - FLOAT: floating-point stride/offset, with the integer fields set to 0.
    - integer types: fixed-point stride/offset scaled by ``1 << shift``.
      Shift values 1..11 are searched, up to 12 attempts, for one whose
      values stay within the (+/-16 << shift) bounds and whose simulated
      index walk stays valid.  The permutation is skipped if none is found.
    When ``error_name`` is set, TosaErrorIfArgGen.eiResizeErrorIf corrupts
    the parameters (and possibly the output dtype).

    Returns a list of (name, [mode, stride, offset, shift, stride_fp,
    offset_fp, output_dims, input_dtype, output_dtype]) tuples.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

        # Exclude illegal {mode, type} configurations. Pick legal output types
        if mode == ResizeMode.NEAREST and dtype == DType.INT8:
            outputDTypeList = [DType.INT8]
        elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
            outputDTypeList = [DType.INT16]
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
            outputDTypeList = [DType.INT32]
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
            outputDTypeList = [DType.INT48]
        elif dtype == DType.FLOAT:
            outputDTypeList = [DType.FLOAT]
        elif error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            continue

        for outputDType in outputDTypeList:
            for perm in range(testGen.args.num_rand_permutations):
                # Randomly generate legal output dimensions and shift
                # and then compute the stride and offset based on them
                # A output_dim of 1 will cause offset to exceed allowed range
                # so minimum value 2 produced below
                output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                # Keep the input/output ratio (the stride) below 16
                while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                    output_dims[0] += 1
                while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                    output_dims[1] += 1

                # Align the centres of the input and output grids
                in_center_h = (ifm_shape[1] - 1) / 2.0
                in_center_w = (ifm_shape[2] - 1) / 2.0
                out_center_h = (output_dims[0] - 1) / 2.0
                out_center_w = (output_dims[1] - 1) / 2.0

                fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                fp_offset_y = in_center_h - fp_stride_y * out_center_h
                fp_offset_x = in_center_w - fp_stride_x * out_center_w

                if outputDType == DType.FLOAT:
                    # Floating-point resize uses only the *_fp parameters
                    shift = 0
                    stride = [0, 0]
                    offset = [0, 0]
                    stride_fp = [fp_stride_y, fp_stride_x]
                    offset_fp = [fp_offset_y, fp_offset_x]

                    if error_name is not None:
                        shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            shift,
                            stride,
                            stride_fp,
                            offset,
                            offset_fp
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_list.append(
                        (
                            "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                "N" if mode == ResizeMode.NEAREST else "B",
                                output_dims[0],
                                output_dims[1],
                                testGen.typeStr(outputDTypeNew),
                                stride_fp[0],
                                stride_fp[1],
                                offset_fp[0],
                                offset_fp[1],
                            ),
                            [
                                mode,
                                stride,
                                offset,
                                shift,
                                stride_fp,
                                offset_fp,
                                output_dims,
                                dtype,
                                outputDTypeNew,
                            ],
                        )
                    )
                else:
                    shift = testGen.randInt(1,12)
                    # Now search for a shift value (1 to 11) that will produce
                    # a valid and predictable resize operation
                    # (each retry advances shift cyclically: 11 wraps to 1)
                    count = 0
                    while (count < 12):
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        if (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            # Change the shift value and check again
                            count += 1
                            shift = (shift % 11) + 1
                            continue

                        def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift):
                            # Perform the pseudo loop to look for out of bounds.
                            # Returns the (clamped) index pair at the first bad
                            # position, or at the last position if none is bad.
                            # Assumes length_out >= 1 (guaranteed above).
                            for pos in range(0,length_out):
                                a = pos * stride + offset
                                ia = a >> shift
                                ia0 = max(ia, 0)
                                ia1 = min(ia+1, length_in-1)
                                if ia0 > ia1:
                                    # Found a problem value
                                    break
                            return ia0, ia1

                        iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift)
                        ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift)
                        if ix0 > ix1 or iy0 > iy1:
                            # Change the shift value and check again
                            count += 1
                            shift = (shift % 11) + 1
                            continue
                        break

                    if count >= 12:
                        # Couldn't find a good set of values for this test, skip it
                        continue

                    stride = [stride_y, stride_x]
                    offset = [offset_y, offset_x]

                    # Integer resize uses only the fixed-point parameters
                    stride_fp = [0.0, 0.0]
                    offset_fp = [0.0, 0.0]

                    if error_name is not None:
                        shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            shift,
                            stride,
                            stride_fp,
                            offset,
                            offset_fp
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_list.append(
                        (
                            "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                "N" if mode == ResizeMode.NEAREST else "B",
                                shift,
                                output_dims[0],
                                output_dims[1],
                                testGen.typeStr(outputDTypeNew),
                                stride[0],
                                stride[1],
                                offset[0],
                                offset[1],
                            ),
                            [
                                mode,
                                stride,
                                offset,
                                shift,
                                stride_fp,
                                offset_fp,
                                output_dims,
                                dtype,
                                outputDTypeNew,
                            ],
                        )
                    )

    return arg_list
1262
@staticmethod
def agTable(testGen, opName, shapeList, dtype, error_name=None):
    """Generate a random TABLE lookup table as the single argument.

    INT8 gets 256 entries drawn from [-128, 128); any other dtype is treated
    as INT16 and gets 513 entries drawn from [-32768, 32768).  Values are
    returned as a plain Python list under an empty test name.
    """
    if dtype == DType.INT8:
        low, high, entries = -128, 128, 256
    else:  # INT16
        low, high, entries = -32768, 32768, 513

    raw_values = testGen.rng.integers(low=low, high=high, size=[entries])
    table = np.int32(raw_values).tolist()
    return [("", [table])]
1283
@staticmethod
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Return the COND_IF condition values to test: False, then True.

    Only the boolean is produced here.  It is converted to a tensor in the
    build function, along with the then and else blocks.

    Declared @staticmethod like the other generators, so it behaves the same
    whether it is reached through the class or through an instance.
    """
    return [("cond{}".format(int(cond)), [cond]) for cond in (False, True)]
1294
@staticmethod
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Return WHILE_LOOP iteration counts to test: 0, 1 and more than 1 (4).

    Declared @staticmethod like the other generators.  The loop variable no
    longer shadows the ``iter`` builtin.
    """
    return [("iter{}".format(iterations), [iterations]) for iterations in (0, 1, 4)]
1303
Matthew Haddone86fd342021-09-07 16:12:21 +01001304class TosaErrorIfArgGen:
1305
@staticmethod
def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
    """Corrupt RESIZE parameters so that ``error_name`` is triggered.

    FLOAT outputs have their floating-point stride or their shift corrupted.
    Integer outputs have their fixed-point stride, shift or offset corrupted.
    For WrongOutputType, an output dtype that is illegal for (mode, dtype) is
    picked at random.  Unrecognised error names leave the values unchanged.

    Note: for StrideLargerDimension the ``stride``/``stride_fp`` list passed
    in is modified in place.

    Returns (shift, stride, stride_fp, offset, offset_fp, outputDType).
    """

    if outputDType == DType.FLOAT:
        if error_name == ErrorIf.StrideSmallerEqualZero:
            # random() is in [0, 1), so both strides end up negative
            stride_fp = testGen.rng.random(size=[2]) - 2
        elif error_name == ErrorIf.ShiftNotZero:
            shift = testGen.rng.integers(1, 5)
        elif error_name == ErrorIf.StrideLargerDimension:
            shape = shapeList[0]
            transform_height = testGen.rng.choice([False, True])
            if transform_height:
                stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
            else:
                stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
    else:
        if error_name == ErrorIf.StrideSmallerEqualZero:
            stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
        elif error_name == ErrorIf.ShiftSmallerOne:
            shift = testGen.rng.integers(-3, 1)
            # NOTE(review): integers(-3, 1) excludes 1 (numpy Generator
            # semantics), so shift <= 0 always and the else branch below is
            # unreachable.
            if shift <= 0:
                stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
            else:
                stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
                offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
        elif error_name == ErrorIf.ShiftLargerEleven:
            shift = np.int16(testGen.rng.integers(12, 15))
        elif error_name == ErrorIf.StrideLargerDimension:
            shape = shapeList[0]
            transform_height = testGen.rng.choice([False, True])
            if transform_height:
                stride[0] = shape[1] + testGen.rng.integers(1, 10)
            else:
                stride[1] = shape[2] + testGen.rng.integers(1, 10)
        elif error_name == ErrorIf.StrideLargerEqualMax:
            stride = [(16 << shift) + 1, (16 << shift) + 1]
        elif error_name == ErrorIf.OffsetLargerEqualMax:
            offset = [(16 << shift) + 1, (16 << shift) + 1]
        elif error_name == ErrorIf.OffsetSmallerEqualMin:
            offset = [(-16 << shift) - 1, (-16 << shift) - 1]


    if error_name == ErrorIf.WrongOutputType:
        # Each list holds every output dtype except the legal one
        # NOTE(review): if (mode, dtype) matches none of these branches,
        # incorrect_types is unbound and the choice below raises.
        if mode == ResizeMode.NEAREST and dtype == DType.INT8:
            incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
        elif dtype == DType.FLOAT:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
        outputDType = testGen.rng.choice(a=incorrect_types)

    return shift, stride, stride_fp, offset, offset_fp, outputDType
1363
Matthew Haddone807aae2021-10-11 18:12:58 +01001364
@staticmethod
def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
    """Corrupt one pooling argument so that ``error_name`` is triggered.

    Returns (stride, pad, kernel) with the relevant entry replaced by random
    illegal values.  Returns (None, None, None) when ``error_name`` is not
    handled here, or for StrideSmallerOne when some pad is not strictly
    smaller than its kernel size.  pad[0:2] is checked against kernel[0] and
    pad[2:4] against kernel[1].
    """
    def pick(options):
        return testGen.rng.choice(options)

    if error_name == ErrorIf.StrideSmallerOne:
        # padding must not exceed the kernel size
        pad_fits = (pad[0] < kernel[0] and pad[1] < kernel[0]
                    and pad[2] < kernel[1] and pad[3] < kernel[1])
        if not pad_fits:
            return None, None, None
        bad_stride = (pick([0, -1, -2, -3]), pick([0, -1, -2, -3]))
        return bad_stride, pad, kernel

    if error_name == ErrorIf.PadSmallerZero:
        bad_pad = (pick([-1, -2, -3]), pick([-1, -2, -3]),
                   pick([-1, -2, -3]), pick([-1, -2, -3]))
        return stride, bad_pad, kernel

    if error_name == ErrorIf.KernelSmallerOne:
        bad_kernel = (pick([0, -1, -2, -3]), pick([0, -1, -2, -3]))
        return stride, pad, bad_kernel

    if error_name == ErrorIf.PadLargerEqualKernel:
        first_axis = [kernel[0], kernel[0] + 1, kernel[0] + 2]
        second_axis = [kernel[1], kernel[1] + 1, kernel[1] + 2]
        bad_pad = (pick(first_axis), pick(first_axis),
                   pick(second_axis), pick(second_axis))
        return stride, bad_pad, kernel

    return None, None, None
1389
Matthew Haddone807aae2021-10-11 18:12:58 +01001390
@staticmethod
def eiRescaleWrongOutputType(input_dtype, output_dtype):
    """Return True if ``output_dtype`` is not a legal RESCALE output for ``input_dtype``.

    Legal outputs are:
    - INT8 input: UINT8, INT8, INT16 or INT32;
    - INT16, INT32 or INT48 input: INT8, INT16 or INT32;
    - UINT8 input: INT8 only.
    Any other input dtype returns False.
    """
    if input_dtype == DType.INT8:
        legal_outputs = (DType.UINT8, DType.INT8, DType.INT16, DType.INT32)
    elif input_dtype in (DType.INT16, DType.INT32, DType.INT48):
        legal_outputs = (DType.INT8, DType.INT16, DType.INT32)
    elif input_dtype == DType.UINT8:
        legal_outputs = (DType.INT8,)
    else:
        return False
    return output_dtype not in legal_outputs
1406
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001407
@staticmethod
def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
    """Mess up the operator's input/output tensor lists for ERROR_IF checks.

    For "WrongInputList", a dummy input is randomly either appended or the
    last input is dropped.  For "WrongOutputList", a dummy output is randomly
    either appended or all outputs are removed.  Other error names return the
    lists unchanged.

    The caller's lists are never modified.  The original appended in place
    but returned fresh lists on removal, so callers already had to use the
    returned values; now every branch returns new lists.

    Returns (input_list, output_list).
    """
    if error_name == "WrongInputList":
        add_input = testGen.rng.choice([True, False])
        if add_input:
            input_list = input_list + ['eiDummyInput']
        else:
            input_list = input_list[:-1]
    if error_name == "WrongOutputList":
        add_output = testGen.rng.choice([True, False])
        if add_output:
            output_list = output_list + ['eiDummyOutput']
        else:
            output_list = []
    return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001424
@staticmethod
def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
    """Restrict the dimensions and overall size of a shape to max_dim and max_items.

    Every dimension is first clamped to ``max_dim``. Then all dimensions are
    decremented together (never below 1) until ``product(new_shape)`` is at
    most ``max_items`` or every dimension is already 1.

    Args:
        shape: sequence of dimension sizes (may be empty).
        max_dim: largest size allowed for a single dimension.
        max_items: largest total element count allowed.

    Returns:
        The restricted shape. This is the input object itself when no
        clamping or shrinking was needed.
    """
    # len() check: max() raises on an empty shape.
    if len(shape) > 0 and max(shape) > max_dim:
        new_shape = [min(d, max_dim) for d in shape]
    else:
        new_shape = shape
    # Stop once every dimension is 1. Shrinking further is impossible, and
    # without this check the loop never ends when max_items < 1.
    while product(new_shape) > max_items and any(d > 1 for d in new_shape):
        new_shape = [max(d - 1, 1) for d in new_shape]
    return new_shape
Matthew Haddone807aae2021-10-11 18:12:58 +01001432
@staticmethod
def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
    """Return SLICE (start, size) arguments that trigger ``error_name``.

    Random values are drawn from ``testGen.rng``. For an ``error_name`` not
    handled here, ``start`` and ``size`` are returned unchanged. The caller's
    ``start``/``size`` lists are never modified in place.

    Returns:
        (start, size) tuple for the SLICE operator.
    """
    rank = len(input_shape)
    if error_name == ErrorIf.StartSmallerZero:
        return [testGen.rng.choice([-3, -2, -1]) for _ in range(rank)], size
    if error_name == ErrorIf.SizeSmallerEqualZero:
        return start, [testGen.rng.choice([-3, -2, -1, 0]) for _ in range(rank)]
    if error_name == ErrorIf.StartSizeOutsideBounds:
        # Start at the last index of each dimension; any size >= 2 overruns it.
        newStart = [dim - 1 for dim in input_shape]
        newSize = [testGen.rng.choice([2, 3, 4]) for _ in range(rank)]
        return newStart, newSize
    if error_name == ErrorIf.InputSizeStartLengthMismatch:
        if testGen.rng.choice([True, False]):
            # One entry shorter than the input rank.
            return start[1:], size[1:]
        # One entry longer than the input rank. Build new lists instead of
        # appending to the caller's.
        return [*start, 1], [*size, 1]
    return start, size
1463
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
    """Return the output dtypes CAST must reject for ``input_dtype``.

    ``testGen`` is unused and kept only for call-site compatibility.

    Raises:
        ValueError: if ``input_dtype`` has no wrong-output list defined here.
    """
    if input_dtype in [DType.BOOL, DType.FLOAT]:
        return [DType.BOOL, DType.INT48, DType.FLOAT]
    if input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
        return [DType.INT48]
    # Fail loudly. An `assert True` here can never fire and would leave the
    # result unbound.
    raise ValueError(f"input_dtype ({input_dtype}) not supported")
1473
1474
Matthew Haddone86fd342021-09-07 16:12:21 +01001475class TosaErrorValidator:
1476
@staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
    """Run every ERROR_IF validator and record the expected failure.

    Each function in ``validator_fcns`` is called as ``fcn(True, **kwargs)``
    and must return a dict with 'error_name', 'error_result' and
    'error_reason'. When the validator named ``error_name`` fires,
    ``serializer.setExpectedReturnCode(2, reason)`` is called.

    Returns:
        True when the validators agree with ``error_name``: only the
        requested check fired, or nothing fired and no check was requested.
        None when the test should be discarded, because a different check
        fired or the requested one did not.
    """
    for val_fcn in validator_fcns:
        val_result = val_fcn(True, **kwargs)

        validator_name = val_result['error_name']
        error_result = val_result['error_result']
        error_reason = val_result['error_reason']

        if error_result:
            if error_name == validator_name:
                serializer.setExpectedReturnCode(2, error_reason)
            else:
                print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
                return None  # Return None to delete test if wrong ERROR_IF is hit
        elif error_name == validator_name:
            print(f"No ERROR_IF hit for {error_name}")
            return None
    # Distinguish success from the "delete this test" None above.
    return True
1498
@staticmethod
def evWrongInputType(check=False, **kwargs):
    """ERROR_IF validator: the input dtype is not supported by the operator.

    Allowed input dtypes come from ``op['types']``. An entry that is a list
    contributes only its first element.

    Args:
        check: when True, evaluate the condition for ``kwargs['input_dtype']``.
        **kwargs: must contain 'op'; with check=True also 'input_dtype'.

    Returns:
        dict with error_name, error_result, error_reason and param_reqs.
        param_reqs['dtype'] lists the input dtypes the operator rejects.
    """
    all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}

    # Find the unsupported input data types
    assert 'op' in kwargs
    op = kwargs['op']
    input_dtypes = op['types']

    allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
    wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)

    # INT48 is not used as a wrong input type for CLAMP. Guard the removal so
    # it cannot raise ValueError when INT48 is absent.
    if op['op'] == Op.CLAMP and DType.INT48 in wrong_input_dtypes:
        wrong_input_dtypes.remove(DType.INT48)

    error_name = ErrorIf.WrongInputType
    param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
    error_result = False
    error_reason = "Input data type not supported for this operator"

    if check:
        input_dtype = kwargs['input_dtype']
        # Use the flattened set for every op. List-typed entries (as for
        # FULLY_CONNECTED) would otherwise never equal a scalar dtype. For
        # scalar-typed ops this is the same as testing op['types'] directly.
        if input_dtype not in allowed_input_dtypes:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1534
@staticmethod
def evWrongOutputType(check=False, **kwargs):
    """ERROR_IF validator: the output dtype is invalid for the op and input dtype.

    Args:
        check: when True, evaluate the condition; otherwise only describe it.
        **kwargs: with check=True needs 'input_dtype', 'output_dtype' and
            'op', plus 'mode' when the op is RESIZE.

    Returns:
        dict with error_name, error_result, error_reason and param_reqs.
    """
    error_name = ErrorIf.WrongOutputType
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Output data type not supported for this configuration of operator"

    if check:
        in_t = kwargs['input_dtype']
        out_t = kwargs['output_dtype']
        op_kind = kwargs['op']['op']

        def mismatch(rules):
            # A rule (src, dst) is violated when the input is src but the output is not dst.
            return any(in_t == src and out_t != dst for src, dst in rules)

        if op_kind == Op.RESIZE:
            mode = kwargs['mode']
            if mode == ResizeMode.NEAREST:
                mode_rules = ((DType.INT8, DType.INT8), (DType.INT16, DType.INT16))
            elif mode == ResizeMode.BILINEAR:
                mode_rules = ((DType.INT8, DType.INT32), (DType.INT16, DType.INT48))
            else:
                mode_rules = ()
            # Float input requires float output regardless of mode.
            error_result = mismatch(mode_rules) or mismatch(((DType.FLOAT, DType.FLOAT),))

        elif op_kind == Op.RESCALE:
            if in_t == DType.INT8:
                legal = (DType.UINT8, DType.INT8, DType.INT16, DType.INT32)
            elif in_t in (DType.INT16, DType.INT32, DType.INT48):
                legal = (DType.INT8, DType.INT16, DType.INT32)
            elif in_t == DType.UINT8:
                legal = (DType.INT8,)
            else:
                legal = None
            error_result = legal is not None and out_t not in legal

        elif op_kind in (Op.FULLY_CONNECTED, Op.MATMUL):
            error_result = mismatch((
                (DType.INT8, DType.INT32),
                (DType.INT16, DType.INT48),
                (DType.FLOAT, DType.FLOAT),
            ))

        elif op_kind == Op.ARGMAX:
            error_result = in_t in (DType.INT8, DType.INT16, DType.FLOAT) and out_t != DType.INT32

        elif op_kind == Op.MUL:
            required = DType.FLOAT if in_t == DType.FLOAT else DType.INT32
            error_result = out_t != required

        elif op_kind == Op.TABLE:
            error_result = mismatch(((DType.INT8, DType.INT8), (DType.INT16, DType.INT32)))

        elif op_kind in (Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER):
            error_result = out_t != DType.BOOL

        elif op_kind == Op.CAST:
            cast_targets = (
                (DType.BOOL, (DType.INT8, DType.INT16, DType.INT32)),
                (DType.INT8, (DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT)),
                (DType.INT16, (DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT)),
                (DType.INT32, (DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT)),
                (DType.FLOAT, (DType.INT8, DType.INT16, DType.INT32)),
            )
            error_result = any(in_t == src and out_t not in dsts for src, dsts in cast_targets)

        else:
            # Default: the output must keep the input dtype.
            error_result = out_t != in_t

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1621
@staticmethod
def evWrongRank(check=False, **kwargs):
    """ERROR_IF validator: the input rank is not accepted by the operator.

    param_reqs['rank'] lists ranks to generate for this error. With
    check=True, ``len(kwargs['input_shape'])`` is tested against the op's
    fixed rank (RESIZE/pooling 4, FULLY_CONNECTED 2, MATMUL 3) and then its
    declared ``op['rank']`` range.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    rmin, rmax = op['rank']
    rank_range = range(rmin, rmax + 1)

    # Ranks 1..5 outside the op's range. Drop those below rmin to avoid
    # index errors when building the test.
    incorrect_ranks = [r for r in set((1, 2, 3, 4, 5)) - set(rank_range) if r > rmin]
    # Hand-picked bad ranks for ops where the generic list does not work.
    if op['op'] in [Op.RESIZE]:
        incorrect_ranks = [3, 5]
    if op['op'] in [Op.TRANSPOSE]:
        incorrect_ranks = [7, 8]

    error_name = ErrorIf.WrongRank
    param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Rank not supported for this operator"

    if check:
        rank = len(kwargs['input_shape'])
        if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]:
            required_rank = 4
        elif op['op'] == Op.FULLY_CONNECTED:
            required_rank = 2
        elif op['op'] == Op.MATMUL:
            required_rank = 3
        else:
            required_rank = None
        # A fixed-rank op fails on any other rank. Otherwise, including when
        # the fixed rank matches, the declared rank range decides.
        error_result = (required_rank is not None and rank != required_rank) or rank not in rank_range

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1665
@staticmethod
def evWrongInputList(check=False, **kwargs):
    """ERROR_IF validator: the number of op inputs differs from what is expected.

    With check=True, ``len(kwargs['input_list'])`` is compared with
    ``kwargs['num_operands']``. SCATTER/GATHER expect one extra input.
    """
    error_name = ErrorIf.WrongInputList
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Op input list does not match expected input"

    if check:
        op = kwargs['op']
        input_list = kwargs['input_list']
        expected_count = kwargs['num_operands']
        if op['op'] in [Op.SCATTER, Op.GATHER]:
            # SCATTER/GATHER add an indices input tensor in their build functions
            expected_count += 1
        error_result = len(input_list) != expected_count

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1690
@staticmethod
def evWrongOutputList(check=False, **kwargs):
    """ERROR_IF validator: the op output list does not hold exactly one entry.

    Note this will be incorrect if an operator returns more than one output.
    """
    error_name = ErrorIf.WrongOutputList
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_reason = "Op output list does not match expected output"
    error_result = len(kwargs['output_list']) != 1 if check else False

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
Matthew Haddone86fd342021-09-07 16:12:21 +01001711
@staticmethod
def evMaxDimExceeded(check=False, **kwargs):
    """ERROR_IF validator: an input H/W or output OH/OW dimension exceeds 16384.

    With check=True reads ``input_shape[1]``, ``input_shape[2]`` and
    ``output_shape[0]``, ``output_shape[1]`` (output_shape is just (OH, OW)).
    """
    error_name = ErrorIf.MaxDimExceeded
    param_reqs = {
        "rank": [4, 4],
        "dtype": [DType.INT8],
        "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
    }
    error_result = False
    error_reason = "At least one maximum dimension is larger than 16384"

    if check:
        limit = 16384
        in_shape = kwargs['input_shape']
        out_hw = kwargs['output_shape']  # Note this is just (OH, OW)
        error_result = bool(
            in_shape[1] > limit
            or in_shape[2] > limit
            or out_hw[0] > limit
            or out_hw[1] > limit
        )

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1739
@staticmethod
def evBatchMismatch(check=False, **kwargs):
    """ERROR_IF validator: input dim 0 differs from result_tensor dim 0.

    Only evaluated when the input rank is inside ``op['rank']``.
    """
    assert 'op' in kwargs
    rmin, rmax = kwargs['op']['rank']
    valid_ranks = range(rmin, rmax + 1)

    error_name = ErrorIf.BatchMismatch
    param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Input batch size not equal to output batch size"

    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
        error_result = bool(len(in_shape) in valid_ranks and in_shape[0] != out_shape[0])

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1766
@staticmethod
def evChannelMismatch(check=False, **kwargs):
    """ERROR_IF validator: input dim 3 differs from result_tensor dim 3.

    Only evaluated when the input rank is inside ``op['rank']``.
    """
    assert 'op' in kwargs
    rmin, rmax = kwargs['op']['rank']
    valid_ranks = range(rmin, rmax + 1)

    error_name = ErrorIf.ChannelMismatch
    param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Input channel size not equal to output channel size"

    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
        error_result = bool(len(in_shape) in valid_ranks and in_shape[3] != out_shape[3])

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1792
@staticmethod
def evStrideSmallerEqualZero(check=False, **kwargs):
    """ERROR_IF validator: some stride component is <= 0.

    The stride is read from 'stride_fp' or 'stride' depending on the
    input/output dtype combination (see table below).
    """
    error_name = ErrorIf.StrideSmallerEqualZero
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Stride value smaller than or equal zero"

    if check:
        in_is_fp = kwargs['input_dtype'] == DType.FLOAT
        out_is_fp = kwargs['output_dtype'] == DType.FLOAT
        # Keyed by (input is FLOAT, output is FLOAT). The mixed rows work
        # around the wrong input/output type tests.
        stride_key = {
            (False, True): 'stride',
            (True, True): 'stride_fp',
            (True, False): 'stride_fp',
            (False, False): 'stride',
        }[(bool(in_is_fp), bool(out_is_fp))]
        error_result = bool(min(kwargs[stride_key]) <= 0)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1822
@staticmethod
def evStrideLargerEqualMax(check=False, **kwargs):
    """ERROR_IF validator: an INT8/INT16 stride reaches the shift-scaled bound.

    The bound is ``16 << shift``, or ``16 >> -shift`` for negative shifts.
    Other input dtypes are never flagged.
    """
    error_name = ErrorIf.StrideLargerEqualMax
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Stride value larger than or equal to maximum value"

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride']
        if input_dtype in [DType.INT8, DType.INT16]:
            bound = (16 << shift) if shift >= 0 else (16 >> -shift)
            error_result = bool(stride[0] >= bound or stride[1] >= bound)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1847
1848
@staticmethod
def evStrideLargerDimension(check=False, **kwargs):
    """ERROR_IF validator: a float stride exceeds the input H or W dimension.

    Only FLOAT inputs are checked. ``stride_fp[0]`` is compared with
    ``input_shape[1]`` and ``stride_fp[1]`` with ``input_shape[2]``.
    """
    error_name = ErrorIf.StrideLargerDimension
    param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    error_result = False
    error_reason = "Stride value larger than or equal to H/W dimension"

    if check:
        shape = kwargs['input_shape']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride_fp']

        # Both H and W comparisons must sit under the FLOAT guard. `and` binds
        # tighter than `or`, so without this grouping the W test would fire
        # for every dtype.
        if input_dtype == DType.FLOAT and ((stride[0] > shape[1]) or (stride[1] > shape[2])):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1871
1872
@staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
    """ERROR_IF validator: an offset is <= the shift-scaled lower bound.

    The bound is ``-16 << shift``, or ``-16 >> -shift`` for negative shifts.
    'offset_fp' is used when the output dtype is FLOAT, 'offset' otherwise.
    """
    error_name = ErrorIf.OffsetSmallerEqualMin
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Offset value smaller than or equal to minimum value"

    if check:
        shift = kwargs['shift']
        offset_key = 'offset_fp' if kwargs['output_dtype'] == DType.FLOAT else 'offset'
        offset = kwargs[offset_key]
        floor = (-16 << shift) if shift >= 0 else (-16 >> -shift)
        error_result = bool(offset[0] <= floor or offset[1] <= floor)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1900
@staticmethod
def evOffsetLargerEqualMax(check=False, **kwargs):
    """ERROR_IF validator: an offset is >= the shift-scaled upper bound.

    The bound is ``16 << shift``, or ``16 >> -shift`` for negative shifts.
    'offset_fp' is used when the output dtype is FLOAT, 'offset' otherwise.
    """
    error_name = ErrorIf.OffsetLargerEqualMax
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Offset value larger than or equal to maximum value"

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        if output_dtype == DType.FLOAT:
            offset = kwargs['offset_fp']
        else:
            offset = kwargs['offset']

        # Single check; the former standalone `shift >= 0` block duplicated
        # the first branch exactly.
        if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
            error_result = True
        elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1932
@staticmethod
def evShiftNotZero(check=False, **kwargs):
    """ERROR_IF validator: shift is non-zero for a FLOAT -> FLOAT configuration."""
    error_name = ErrorIf.ShiftNotZero
    param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    error_result = False
    error_reason = "Shift value must be zero for float input"

    if check:
        shift = kwargs['shift']
        both_float = kwargs['input_dtype'] == DType.FLOAT and kwargs['output_dtype'] == DType.FLOAT
        error_result = bool(both_float and shift != 0)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1954
1955
@staticmethod
def evShiftSmallerOne(check=False, **kwargs):
    """ERROR_IF validator: shift < 1 when neither input nor output is FLOAT."""
    error_name = ErrorIf.ShiftSmallerOne
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_result = False
    error_reason = "Shift value smaller than one"

    if check:
        shift = kwargs['shift']
        in_t = kwargs['input_dtype']
        out_t = kwargs['output_dtype']
        error_result = bool(shift < 1 and in_t != DType.FLOAT and out_t != DType.FLOAT)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1977
@staticmethod
def evShiftLargerEleven(check=False, **kwargs):
    """ERROR_IF validator: shift is greater than 11."""
    error_name = ErrorIf.ShiftLargerEleven
    param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    error_reason = "Shift value larger than eleven"
    error_result = bool(kwargs['shift'] > 11) if check else False

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
1997
1998
@staticmethod
def evRankMismatch(check=False, **kwargs):
    """ERROR_IF validator: an input's rank differs from result_tensor's rank.

    Reads 'input1', 'input2' and (for SELECT) 'input3'; each must have .shape.
    """
    error_name = ErrorIf.RankMismatch
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Input Rank does not match output rank"

    if check:
        first = kwargs['input1'].shape
        second = kwargs['input2'].shape
        # Only SELECT passes input3; otherwise reuse input2, which adds no new check.
        third = kwargs['input3'].shape if 'input3' in kwargs else second
        out_rank = len(kwargs['result_tensor'].shape)
        error_result = any(len(s) != out_rank for s in (first, second, third))

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2026
@staticmethod
def evDimensionMismatch(check=False, **kwargs):
    """ERROR_IF validator: an input dimension is neither 1 nor the output's.

    Dimensions are compared position by position across 'input1', 'input2',
    optional 'input3' (SELECT) and 'result_tensor'. Only positions present
    in all shapes are compared.
    """
    error_name = ErrorIf.DimensionMismatch
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Input Dimensions do not match output"

    if check:
        input1_shape = kwargs['input1'].shape
        input2_shape = kwargs['input2'].shape
        # In case of SELECT op
        input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
        output_shape = kwargs['result_tensor'].shape
        # zip stops at the shortest shape, so an output of lower rank than the
        # inputs cannot raise IndexError. Rank differences are evRankMismatch's job.
        for dim1, dim2, dim3, out_dim in zip(input1_shape, input2_shape, input3_shape, output_shape):
            if any(d != 1 and d != out_dim for d in (dim1, dim2, dim3)):
                error_result = True
                break

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2055
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
    """ERROR_IF validator: non-zero input zero point for a non-INT8/UINT8 input.

    ``kwargs['qinfo']`` is either a tuple (input_zp, output_zp) or an object
    whose ``.ints`` is indexed as ints[0][1] = input_zp and
    ints[1][1] = output/second zp. MATMUL checks both of its inputs
    ('input_dtype' and 'input2_dtype') against INT8.
    """
    op = kwargs['op']
    candidate_dtypes = op['types'].copy()
    # If inputDtypes is a list then only the first two elements are INT8 inputs.
    # NOTE(review): op['types'].copy() looks to always be a list, so this
    # always drops the first two entries; confirm that is intended.
    if isinstance(candidate_dtypes, list):
        candidate_dtypes = candidate_dtypes[2:]
    for zp_dtype in (DType.INT8, DType.UINT8):
        if zp_dtype in candidate_dtypes:
            candidate_dtypes.remove(zp_dtype)

    error_name = ErrorIf.InputZeroPointNotZero
    param_reqs = {
        "rank": None,
        "dtype": candidate_dtypes,
        "shape": None
    }
    error_result = False
    error_reason = "Input DType not INT8 and zero point not 0"

    if check:
        input_dtype = kwargs['input_dtype']
        raw_qinfo = kwargs['qinfo']
        if isinstance(raw_qinfo, tuple):
            input_zero_point = raw_qinfo[0]
        else:
            input_zero_point = raw_qinfo.ints[0][1]

        if op['op'] == Op.MATMUL:
            a_dtype = kwargs['input_dtype']
            b_dtype = kwargs['input2_dtype']
            zp_table = kwargs['qinfo'].ints
            a_zp, b_zp = zp_table[0][1], zp_table[1][1]
            error_result = bool(
                (a_dtype != DType.INT8 and a_zp != 0)
                or (b_dtype != DType.INT8 and b_zp != 0)
            )
        else:
            error_result = bool(input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2107
2108
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
    """ERROR_IF validator: non-zero weight zero point for a non-INT8 weight.

    With check=True the weight zero point is ``kwargs['qinfo'].ints[1][1]``.
    """
    op = kwargs['op']

    # Exclude list-typed entries whose weight (element 1) is INT8.
    weightDtypes = []
    for entry in op['types']:
        has_int8_weight = isinstance(entry, list) and entry[1] == DType.INT8
        if not has_int8_weight:
            weightDtypes.append(entry)

    error_name = ErrorIf.WeightZeroPointNotZero
    param_reqs = {
        "rank": None,
        "dtype": weightDtypes,
        "shape": None
    }
    error_result = False
    error_reason = "Weight DType not INT8 and zero point not 0"

    if check:
        weight_dtype = kwargs['weight_dtype']
        # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
        weight_zero_point = kwargs['qinfo'].ints[1][1]
        error_result = bool(weight_dtype != DType.INT8 and weight_zero_point != 0)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2141
2142
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
    """ERROR_IF validator: non-zero output zero point for a non-INT8/UINT8 output.

    AVG_POOL2D instead keys the rule on its input dtype being INT8.
    ``kwargs['qinfo']`` is a tuple (input_zp, output_zp) or an object whose
    ``.ints[1][1]`` is the output zero point.
    """
    op = kwargs['op']
    candidate_dtypes = op['types'].copy()
    for zp_dtype in (DType.INT8, DType.UINT8):
        if zp_dtype in candidate_dtypes:
            candidate_dtypes.remove(zp_dtype)

    error_name = ErrorIf.OutputZeroPointNotZero
    param_reqs = {
        "rank": None,
        "dtype": candidate_dtypes,
        "shape": None
    }
    error_result = False
    error_reason = "Output DType not INT8 and zero point not 0"

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        raw_qinfo = kwargs['qinfo']
        if isinstance(raw_qinfo, tuple):
            output_zero_point = raw_qinfo[1]
        else:
            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            output_zero_point = raw_qinfo.ints[1][1]

        if op['op'] == Op.AVG_POOL2D:
            error_result = bool(input_dtype != DType.INT8 and output_zero_point != 0)
        else:
            error_result = bool(output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2184
@staticmethod
def evAxisSmallerZero(check=False, **kwargs):
    """ERROR_IF validator: ``kwargs['axis']`` is negative."""
    error_name = ErrorIf.AxisSmallerZero
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_reason = "Axis smaller than zero"
    error_result = bool(kwargs['axis'] < 0) if check else False

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2204
2205
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
    """ERROR_IF validator: ``axis`` is not a valid index into ``input_shape``.

    Valid axes are 0 .. rank-1, so ``axis == rank`` is already out of range.
    This matches the ``0 <= axis < len(shape)`` bound in evShapeOfAxisNotOne.
    """
    error_name = ErrorIf.AxisLargerRank
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Axis larger than rank"

    if check:
        axis = kwargs['axis']
        shape = kwargs['input_shape']
        if axis >= len(shape):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2226
2227
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
    """ERROR_IF validator: ``output_shape[axis]`` is not 1 (for an in-range axis)."""
    error_name = ErrorIf.ShapeOfAxisNotOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "shape[axis] is not equal to 1"

    if check:
        axis = kwargs['axis']
        out_shape = kwargs['output_shape']
        # Out-of-range axes are left to the axis range validators.
        axis_in_range = 0 <= axis < len(out_shape)
        error_result = bool(axis_in_range and out_shape[axis] != 1)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2248
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002249
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
    """ERROR_IF validator: some padding value is negative.

    For PAD, ``kwargs['pad']`` is a nested sequence and every inner sequence
    is checked. For other ops it is a flat sequence.
    """
    error_name = ErrorIf.PadSmallerZero
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one pad is smaller than zero"

    if check:
        op = kwargs['op']
        pad = kwargs['pad']
        if op['op'] == Op.PAD:
            per_entry_min = [min(padding) for padding in pad]
            error_result = any(m < 0 for m in per_entry_min)
        else:
            error_result = bool(min(pad) < 0)

    return {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
2275
2276
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.PadLargerEqualKernel condition.

    With ``check`` set, ``kwargs`` supplies ``pad`` as
    (top, bottom, left, right) and ``kernel`` as (y, x). The error is
    flagged when a top/bottom pad is >= kernel[0] or a left/right pad is
    >= kernel[1].

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.PadLargerEqualKernel
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one pad is larger than kernel dimension"

    if check:
        pad = kwargs['pad']
        kernel = kwargs['kernel']
        # Only judged once pads and kernel are individually valid (negative
        # pads / kernels < 1 are the PadSmallerZero / KernelSmallerOne
        # cases). This matches the validity test in
        # evPoolingOutputShapeMismatch. A zero pad or a size-1 kernel must
        # not switch the check off.
        if min(pad) >= 0 and min(kernel) >= 1:
            if (pad[0] >= kernel[0] or pad[1] >= kernel[0]
                    or pad[2] >= kernel[1] or pad[3] >= kernel[1]):
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2298
@staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.PoolingOutputShapeMismatch condition.

    With ``check`` set, ``kwargs`` supplies the following:
    - ``pad`` as (top, bottom, left, right)
    - ``kernel`` as (y, x)
    - ``stride`` as (y, x)
    - ``input_shape`` and ``output_shape``, whose H and W are read from
      indices 1 and 2

    The expected size per axis is
    ``(in + pad_a + pad_b + stride - kernel) // stride``. A mismatch is
    only reported when kernel >= 1, stride >= 1, pad >= 0 and every pad is
    smaller than its kernel dimension.
    """
    info_dict = {
        "error_name": ErrorIf.PoolingOutputShapeMismatch,
        "error_result": False,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    pad = kwargs['pad']
    top, bottom, left, right = pad[0], pad[1], pad[2], pad[3]

    kernel = kwargs['kernel']
    k_y, k_x = kernel[0], kernel[1]

    in_shape = kwargs['input_shape']
    in_h, in_w = in_shape[1], in_shape[2]

    out_shape = kwargs['output_shape']
    out_h, out_w = out_shape[1], out_shape[2]

    stride = kwargs['stride']
    s_y, s_x = stride[0], stride[1]

    params_valid = (
        min(kernel) >= 1
        and min(stride) >= 1
        and min(pad) >= 0
        and top < k_y and bottom < k_y and left < k_x and right < k_x
    )
    if not params_valid:
        return info_dict

    # Strides are >= 1 here, so the divisions are safe.
    expected_h = (in_h + top + bottom + s_y - k_y) // s_y
    expected_w = (in_w + left + right + s_x - k_x) // s_x
    if out_h != expected_h or out_w != expected_w:
        info_dict["error_result"] = True
    return info_dict
2341
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ArgmaxOutputShapeMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape``,
    ``output_shape`` and ``axis``. The expected output shape is the input
    shape with dimension ``axis`` removed. The error is flagged when any
    remaining dimension differs.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.ArgmaxOutputShapeMismatch
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Mismatch between output shape provided and expected output shape"

    if check:
        output_shape = kwargs['output_shape']
        input_shape = kwargs['input_shape']
        axis = kwargs['axis']
        rank = len(input_shape)

        # Dimensions are only compared when the output has rank - 1 dims and
        # axis names a real input dimension. Otherwise (the rank / axis
        # error cases) the comparison would index past the end of
        # output_shape.
        if 0 <= axis < rank and len(output_shape) == rank - 1:
            expected_shape = list(input_shape[:axis]) + list(input_shape[axis + 1:])
            if any(exp != out for exp, out in zip(expected_shape, output_shape)):
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2376
@staticmethod
def evArgmaxOutputRankMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ArgmaxOutputRankMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape``,
    ``output_shape`` and ``axis``. When ``axis`` is within
    [0, len(input_shape)), the error is flagged if the output rank is not
    one less than the input rank.
    """
    info_dict = {
        "error_name": ErrorIf.ArgmaxOutputRankMismatch,
        "error_result": False,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    out_shape = kwargs['output_shape']
    in_shape = kwargs['input_shape']
    axis = kwargs['axis']
    in_rank = len(in_shape)
    if 0 <= axis < in_rank and len(out_shape) != in_rank - 1:
        info_dict["error_result"] = True
    return info_dict
2400
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002401
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.KernelSmallerOne condition.

    With ``check`` set, ``kwargs`` supplies ``kernel``. The error is
    flagged when any kernel dimension is below 1.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.KernelSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The threshold checked below is 1, not 0.
    error_reason = "At least one kernel dimension is smaller than one"

    if check:
        kernel = kwargs['kernel']
        if min(kernel) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2421
@staticmethod
def evStrideSmallerOne(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.StrideSmallerOne condition.

    With ``check`` set, ``kwargs`` supplies ``stride``. The error is
    flagged when any stride value is below 1.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.StrideSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The threshold checked below is 1, not 0.
    error_reason = "At least one stride dimension is smaller than one"

    if check:
        stride = kwargs['stride']
        if min(stride) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2441
@staticmethod
def evScaleTrue(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ScaleTrue condition.

    ``param_reqs`` restricts the dtype to INT48. With ``check`` set,
    ``kwargs`` supplies ``input_dtype`` and ``scale32``. The error is
    flagged when ``scale32`` is truthy for an INT48 input.
    """
    info_dict = {
        "error_name": ErrorIf.ScaleTrue,
        "error_result": False,
        "error_reason": "Scale set to true but input type is INT48",
        "param_reqs": {"rank": None, "dtype": [DType.INT48], "shape": None},
    }
    if not check:
        return info_dict

    in_dtype = kwargs['input_dtype']
    wants_scale32 = kwargs['scale32']
    if wants_scale32 and in_dtype == DType.INT48:
        info_dict["error_result"] = True
    return info_dict
2462
@staticmethod
def evScaleNotTrue(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ScaleNotTrue condition.

    With ``check`` set, ``kwargs`` supplies ``scale32`` and
    ``double_round``. The error is flagged when double rounding is
    requested without ``scale32``.
    """
    info_dict = {
        "error_name": ErrorIf.ScaleNotTrue,
        "error_result": False,
        "error_reason": "Scale set to false but double round set to true",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    wants_scale32 = kwargs['scale32']
    wants_double_round = kwargs['double_round']
    if not wants_scale32 and wants_double_round:
        info_dict["error_result"] = True
    return info_dict
2483
@staticmethod
def evTensorSizeInputOutputMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.TensorSizeInputOutputMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape`` and
    ``output_shape``. The error is flagged when their element counts
    (``np.prod`` of the shapes) differ.
    """
    info_dict = {
        "error_name": ErrorIf.TensorSizeInputOutputMismatch,
        "error_result": False,
        "error_reason": "Input tensor size does not match output tensor size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    in_shape = kwargs['input_shape']
    out_shape = kwargs['output_shape']
    elements_in = np.prod(in_shape)
    elements_out = np.prod(out_shape)
    if elements_in != elements_out:
        info_dict["error_result"] = True
    return info_dict
2506
@staticmethod
def evStartSmallerZero(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.StartSmallerZero condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape`` and ``start``.
    The check only applies when ``start`` has one entry per input
    dimension. It is flagged when any entry is negative.
    """
    info_dict = {
        "error_name": ErrorIf.StartSmallerZero,
        "error_result": False,
        "error_reason": "Starting point smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    in_shape = kwargs['input_shape']
    begin = kwargs['start']
    rank = len(in_shape)
    if len(begin) == rank and any(begin[d] < 0 for d in range(rank)):
        info_dict["error_result"] = True
    return info_dict
2530
2531
@staticmethod
def evSizeSmallerEqualZero(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.SizeSmallerEqualZero condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape`` and ``size``.
    The check only applies when ``size`` has one entry per input
    dimension. It is flagged when any entry is <= 0.
    """
    info_dict = {
        "error_name": ErrorIf.SizeSmallerEqualZero,
        "error_result": False,
        "error_reason": "Size smaller than or equal to zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    in_shape = kwargs['input_shape']
    extent = kwargs['size']
    rank = len(in_shape)
    if len(extent) == rank and any(extent[d] <= 0 for d in range(rank)):
        info_dict["error_result"] = True
    return info_dict
2555
2556
@staticmethod
def evStartSizeOutsideBounds(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.StartSizeOutsideBounds condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape``, ``start`` and
    ``size``. When both ``start`` and ``size`` match the input rank, the
    error is flagged if ``start[d] + size[d]`` exceeds ``input_shape[d]``
    for any dimension ``d``.
    """
    info_dict = {
        "error_name": ErrorIf.StartSizeOutsideBounds,
        "error_result": False,
        "error_reason": "starting point plus size larger than input dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    in_shape = kwargs['input_shape']
    begin = kwargs['start']
    extent = kwargs['size']
    rank = len(in_shape)
    if len(begin) != rank or len(extent) != rank:
        return info_dict

    if any(begin[d] + extent[d] > in_shape[d] for d in range(rank)):
        info_dict["error_result"] = True
    return info_dict
2581
2582
@staticmethod
def evSizeOutputShapeMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.SizeOutputShapeMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape``,
    ``output_shape`` and ``size``. When ``size`` matches the input rank,
    the error is flagged if any ``size[d]`` differs from
    ``output_shape[d]``.
    """
    info_dict = {
        "error_name": ErrorIf.SizeOutputShapeMismatch,
        "error_result": False,
        "error_reason": "Size does not match output dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    in_shape = kwargs['input_shape']
    out_shape = kwargs['output_shape']
    extent = kwargs['size']
    rank = len(in_shape)
    if len(extent) == rank:
        # A full list: every output dimension 0 .. rank-1 is indexed.
        differs = [extent[d] != out_shape[d] for d in range(rank)]
        if any(differs):
            info_dict["error_result"] = True
    return info_dict
2607
@staticmethod
def evInputSizeStartLengthMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.InputSizeStartLengthMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape``, ``start`` and
    ``size``. The error is flagged unless both ``start`` and ``size`` have
    exactly one entry per input dimension.
    """
    info_dict = {
        "error_name": ErrorIf.InputSizeStartLengthMismatch,
        "error_result": False,
        "error_reason": "rank of input not equal to length of start or size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    in_shape = kwargs['input_shape']
    begin = kwargs['start']
    extent = kwargs['size']
    rank = len(in_shape)
    if not (len(begin) == rank and len(extent) == rank):
        info_dict["error_result"] = True
    return info_dict
2630
@staticmethod
def evIndexOutsideBounds(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.IndexOutsideBounds condition.

    With ``check`` set, ``kwargs`` supplies ``input_shape`` and ``perms``.
    The error is flagged when any permutation entry does not address an
    input dimension, i.e. lies outside [0, rank).

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.IndexOutsideBounds
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index outside of allowed bounds"

    if check:
        input_shape = kwargs['input_shape']
        perms = kwargs['perms']
        rank = len(input_shape)

        # Dimension indices run 0 .. rank-1, so index == rank is already
        # out of bounds.
        if any(index < 0 or index >= rank for index in perms):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2654
@staticmethod
def evIndexUsedTwice(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.IndexUsedTwice condition.

    With ``check`` set, ``kwargs`` supplies ``perms``. The error is flagged
    when any permutation entry appears more than once.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.IndexUsedTwice
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index used multiple times"

    if check:
        perms = kwargs['perms']
        # Any repeated entry makes the set smaller than the sequence.
        if len(set(perms)) != len(perms):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002681
@staticmethod
def evMaxSmallerMin(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.MaxSmallerMin condition.

    With ``check`` set, ``kwargs`` supplies ``max_val`` and ``min_val``.
    The error is flagged when ``max_val < min_val``.
    """
    info_dict = {
        "error_name": ErrorIf.MaxSmallerMin,
        "error_result": False,
        "error_reason": "Max value smaller than min value",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    upper = kwargs['max_val']
    lower = kwargs['min_val']
    if upper < lower:
        info_dict["error_result"] = True
    return info_dict
2703
@staticmethod
def evConcatInputRankMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ConcatInputRankMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``inputs`` (objects with a
    ``.shape``) and the reference ``input_shape``. The error is flagged
    when any input's rank differs from the reference rank.
    """
    info_dict = {
        "error_name": ErrorIf.ConcatInputRankMismatch,
        "error_result": False,
        "error_reason": "Input ranks are not identical",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    tensors = kwargs['inputs']
    ref_shape = kwargs['input_shape']
    if any([len(tensor.shape) != len(ref_shape) for tensor in tensors]):
        info_dict["error_result"] = True
    return info_dict
2725
@staticmethod
def evConcatInputDimMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ConcatInputDimMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``inputs``, the reference
    ``input_shape`` and ``axis``. Once every input has the reference rank,
    the error is flagged if any input differs from the reference on a
    dimension other than ``axis``.
    """
    info_dict = {
        "error_name": ErrorIf.ConcatInputDimMismatch,
        "error_result": False,
        "error_reason": "Input dimensions differ on too many axes",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    tensors = kwargs['inputs']
    ref_shape = kwargs['input_shape']
    axis = kwargs['axis']

    # Dimensions are only compared when all ranks agree.
    ranks_match = all([len(tensor.shape) == len(ref_shape) for tensor in tensors])
    if ranks_match and any([
        dim != ref_shape[i] and axis != i
        for tensor in tensors
        for i, dim in enumerate(tensor.shape)
    ]):
        info_dict["error_result"] = True
    return info_dict
2757
@staticmethod
def evConcatShapeSumMismatch(check=False, **kwargs):
    """Describe or evaluate the ErrorIf.ConcatShapeSumMismatch condition.

    With ``check`` set, ``kwargs`` supplies ``inputs``, the reference
    ``input_shape``, ``output_shape`` and ``axis``. When all inputs share
    the reference rank and ``axis`` is a valid dimension, the error is
    flagged if the inputs' sizes along ``axis`` do not sum to
    ``output_shape[axis]``.

    Returns:
        dict with ``error_name``, ``error_result``, ``error_reason`` and
        ``param_reqs``.
    """
    error_name = ErrorIf.ConcatShapeSumMismatch
    param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Sum of dimensions on axis not equal to output dimension"

    if check:
        inputs = kwargs['inputs']
        input_shape = kwargs['input_shape']
        output_shape = kwargs['output_shape']
        axis = kwargs['axis']

        # Valid axes are 0 .. rank-1; with axis == rank the shape[axis]
        # lookups below would raise IndexError.
        rank = len(input_shape)
        valid_params = (
            all(len(tensor.shape) == rank for tensor in inputs)
            and 0 <= axis < rank
        )

        if valid_params:
            axis_dim_sum = sum(tensor.shape[axis] for tensor in inputs)
            if axis_dim_sum != output_shape[axis]:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
2795
@staticmethod
def evInputListThenGraphMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.CondIfInputListThenGraphMismatch.

    With ``check`` set, ``kwargs`` supplies tensors ``a`` and ``b`` and
    ``basicBlocks``, whose index 1 is used as the then block. The error is
    flagged when ``a`` or ``b`` differs in shape from the then block's
    first or second input tensor respectively.
    """
    info_dict = {
        "error_name": ErrorIf.CondIfInputListThenGraphMismatch,
        "error_result": False,
        "error_reason": "Input list shape does not match then-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    lhs = kwargs['a']
    rhs = kwargs['b']
    then_graph = kwargs['basicBlocks'][1]
    in_names = then_graph.inputs
    tensor_map = then_graph.tensors
    if (lhs.shape != tensor_map[in_names[0]].shape) or (rhs.shape != tensor_map[in_names[1]].shape):
        info_dict["error_result"] = True
    return info_dict
2820
2821
@staticmethod
def evInputListElseGraphMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.CondIfInputListElseGraphMismatch.

    With ``check`` set, ``kwargs`` supplies tensors ``a`` and ``b`` and
    ``basicBlocks``, whose index 2 is used as the else block. The error is
    flagged when ``a`` or ``b`` differs in shape from the else block's
    first or second input tensor respectively.
    """
    info_dict = {
        "error_name": ErrorIf.CondIfInputListElseGraphMismatch,
        "error_result": False,
        "error_reason": "Input list shape does not match else-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    lhs = kwargs['a']
    rhs = kwargs['b']
    else_graph = kwargs['basicBlocks'][2]
    in_names = else_graph.inputs
    tensor_map = else_graph.tensors
    if (lhs.shape != tensor_map[in_names[0]].shape) or (rhs.shape != tensor_map[in_names[1]].shape):
        info_dict["error_result"] = True
    return info_dict
2846
2847
@staticmethod
def evOutputListThenGraphMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.CondIfOutputListThenGraphMismatch.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. The error is
    flagged when the first output tensor of block 1 (the then block)
    differs in shape from the first output tensor of block 0.
    """
    info_dict = {
        "error_name": ErrorIf.CondIfOutputListThenGraphMismatch,
        "error_result": False,
        "error_reason": "Output list shape does not match then-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    blocks = kwargs['basicBlocks']
    outer, then_graph = blocks[0], blocks[1]
    outer_result = outer.tensors[outer.outputs[0]]
    then_result = then_graph.tensors[then_graph.outputs[0]]
    if then_result.shape != outer_result.shape:
        info_dict["error_result"] = True
    return info_dict
2873
2874
@staticmethod
def evOutputListElseGraphMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.CondIfOutputListElseGraphMismatch.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. The error is
    flagged when the first output tensor of block 2 (the else block)
    differs in shape from the first output tensor of block 0.
    """
    info_dict = {
        "error_name": ErrorIf.CondIfOutputListElseGraphMismatch,
        "error_result": False,
        "error_reason": "Output list shape does not match else-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    blocks = kwargs['basicBlocks']
    outer, else_graph = blocks[0], blocks[2]
    outer_result = outer.tensors[outer.outputs[0]]
    else_result = else_graph.tensors[else_graph.outputs[0]]
    if else_result.shape != outer_result.shape:
        info_dict["error_result"] = True
    return info_dict
2900
2901
@staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.InputListOutputListMismatch.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. Block 0 is
    used as the while block. The error is flagged when its second input
    tensor differs in shape from its first output tensor.
    """
    info_dict = {
        "error_name": ErrorIf.InputListOutputListMismatch,
        "error_result": False,
        "error_reason": "Input list does not match output list",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    loop_graph = kwargs['basicBlocks'][0]
    tensor_map = loop_graph.tensors
    loop_in = tensor_map[loop_graph.inputs[1]]
    loop_out = tensor_map[loop_graph.outputs[0]]
    if loop_in.shape != loop_out.shape:
        info_dict["error_result"] = True
    return info_dict
2925
2926
@staticmethod
def evInputListCondGraphMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.InputListCondGraphMismatch.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. Block 0 is the
    while block and block 1 the cond block. The error is flagged when while
    input 0 differs in shape from cond input 0, or while input 1 differs
    from cond input 2.
    """
    # NOTE(review): while input 1 is paired with cond input 2; presumably
    # this mirrors how the WHILE_LOOP test graphs are built — confirm.
    info_dict = {
        "error_name": ErrorIf.InputListCondGraphMismatch,
        "error_result": False,
        "error_reason": "Input list does not match cond graph",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    blocks = kwargs['basicBlocks']
    loop_graph, cond_graph = blocks[0], blocks[1]
    loop_ins, loop_map = loop_graph.inputs, loop_graph.tensors
    cond_ins, cond_map = cond_graph.inputs, cond_graph.tensors
    if ((loop_map[loop_ins[0]].shape != cond_map[cond_ins[0]].shape) or
            (loop_map[loop_ins[1]].shape != cond_map[cond_ins[2]].shape)):
        info_dict["error_result"] = True
    return info_dict
2953
2954
@staticmethod
def evInputListBodyGraphInputMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.InputListBodyGraphInputMismatch.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. Block 0 is the
    while block and block 2 the body block. The error is flagged when while
    input 0 differs in shape from body input 0, or while input 1 differs
    from body input 2.
    """
    # NOTE(review): while input 1 is paired with body input 2; presumably
    # intentional for the WHILE_LOOP test graphs — confirm.
    info_dict = {
        "error_name": ErrorIf.InputListBodyGraphInputMismatch,
        "error_result": False,
        "error_reason": "Input list does not match body graph input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    blocks = kwargs['basicBlocks']
    loop_graph, body_graph = blocks[0], blocks[2]
    loop_ins, loop_map = loop_graph.inputs, loop_graph.tensors
    body_ins, body_map = body_graph.inputs, body_graph.tensors
    if ((loop_map[loop_ins[0]].shape != body_map[body_ins[0]].shape) or
            (loop_map[loop_ins[1]].shape != body_map[body_ins[2]].shape)):
        info_dict["error_result"] = True
    return info_dict
2981
2982
@staticmethod
def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
    """Describe or evaluate ErrorIf.InputListBodyGraphOutputMismatch.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. Block 0 is the
    while block and block 2 the body block. The error is flagged when while
    input 0 differs in shape from body output 0, or while input 1 differs
    from body output 2.
    """
    info_dict = {
        "error_name": ErrorIf.InputListBodyGraphOutputMismatch,
        "error_result": False,
        "error_reason": "Input list does not match body graph output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    blocks = kwargs['basicBlocks']
    loop_graph, body_graph = blocks[0], blocks[2]
    loop_ins, loop_map = loop_graph.inputs, loop_graph.tensors
    body_outs, body_map = body_graph.outputs, body_graph.tensors
    if ((loop_map[loop_ins[0]].shape != body_map[body_outs[0]].shape) or
            (loop_map[loop_ins[1]].shape != body_map[body_outs[2]].shape)):
        info_dict["error_result"] = True
    return info_dict
3008
3009
@staticmethod
def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
    """Describe or evaluate ErrorIf.CondGraphOutputNotMatchingBool.

    With ``check`` set, ``kwargs`` supplies ``basicBlocks``. The error is
    flagged when the first output tensor of block 1 (the cond block) does
    not have dtype DType.BOOL.
    """
    info_dict = {
        "error_name": ErrorIf.CondGraphOutputNotMatchingBool,
        "error_result": False,
        "error_reason": "Cond graph output is not a match list of booleans",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info_dict

    cond_graph = kwargs['basicBlocks'][1]
    cond_result = cond_graph.tensors[cond_graph.outputs[0]]
    if cond_result.dtype != DType.BOOL:
        info_dict["error_result"] = True
    return info_dict
3032
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003033
class TosaInvalidValidator:
    """Predicates that flag generated argument combinations as invalid.

    Every ``iv*`` method takes keyword arguments describing one candidate
    test and returns True when that combination is invalid.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode / dtype combination is invalid.

        Expects ``input_dtype`` and ``args``. ``args[0]`` is the resize mode
        and ``args[8]`` the output dtype. Unknown modes are always invalid.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # BILINEAR is valid only for these (input, output) pairs; any
            # other pairing (including any other input dtype) is invalid.
            valid_pairs = (
                (DType.INT8, DType.INT32),
                (DType.INT16, DType.INT48),
                (DType.FLOAT, DType.FLOAT),
            )
            return (input_dtype, output_dtype) not in valid_pairs
        elif mode == ResizeMode.NEAREST:
            # NEAREST keeps the input dtype. The accepted inputs are the same
            # three types that appear in the BILINEAR pairs above.
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if the RESIZE stride is zero or negative.

        FLOAT inputs are checked against the floating-point stride in
        ``args[4]``; all other dtypes against the integer stride in
        ``args[1]``.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if the computed output height or width is <= 0.

        Handles op names starting with ``conv2d`` or ``depthwise_conv2d`` and
        op names ending with ``pool2d``. ``shapeList[0]`` is the input shape
        and, except for pooling, ``shapeList[1]`` is the filter shape.
        ``args`` is (strides, padding, dilations) for convolutions and
        (strides, padding, kernel) for pooling.

        Raises:
            AssertionError: for any other op name.
        """
        opName = kwargs['opName']
        is_pool = opName.endswith("pool2d")

        inputShapes = kwargs['shapeList']
        ifm_shape = inputShapes[0]
        if not is_pool:
            filter_shape = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if is_pool:
            kernel = args[2]

        if opName.startswith('conv2d'):
            # conv2d filter dims 1 and 2 are the kernel height / width.
            h = (
                ifm_shape[1]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[2]
                - (filter_shape[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            # depthwise filter dims 0 and 1 are the kernel height / width.
            h = (
                ifm_shape[1]
                - filter_shape[0]
                - (filter_shape[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif is_pool:
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            # Explicit raise: a bare ``assert`` is stripped under -O and
            # would leave h / w undefined.
            raise AssertionError("Unrecognized Op")

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if dimension 1 or 2 of ``args[3]`` (output shape) is <= 0."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
3150
3151
Kevin Cheng550ccc52021-03-03 11:21:43 -08003152
Eric Kunzee5e26762020-10-13 16:11:07 -07003153class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003154 # Maximum rank of tensor supported by test generator.
3155 TOSA_TENSOR_MAX_RANK = 6
3156
def __init__(self, args):
    """Initialise the generator from ``args``.

    ``args`` (presumably the argparse namespace) must provide
    ``output_dir`` and ``random_seed``. ``self.rng`` is seeded before the
    op lists are built. ``self.ser`` stays None until createSerializer()
    assigns it.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(args.random_seed)

    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()

    # When set (via setTargetShape), makeShape() returns this shape.
    self.targetted_shape = None
3168
def createSerializer(self, opName, testPath):
    """Create the output directory ``basePath/opName/testPath`` and a serializer for it.

    Sets ``self.testPath`` to ``opName/testPath`` and ``self.ser`` to a new
    ``ts.TosaSerializer`` rooted at the full directory.
    """
    self.testPath = os.path.join(opName, testPath)
    out_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(out_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(out_dir)
3175
def getSerializer(self):
    """Return the serializer last stored in ``self.ser`` (None before createSerializer)."""
    current = self.ser
    return current
3178
def serialize(self, testName):
    """Write ``<testName>.tosa`` and ``desc.json`` into the current test directory.

    The flatbuffer comes from ``self.ser.serialize()`` (written in binary
    mode). The JSON description comes from
    ``self.ser.writeJson("<testName>.tosa")``.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_name = "{}.tosa".format(testName)

    with open(os.path.join(test_dir, tosa_name), "wb") as fd:
        fd.write(self.ser.serialize())

    with open(os.path.join(test_dir, "desc.json"), "w") as fd:
        fd.write(self.ser.writeJson(tosa_name))
Eric Kunzee5e26762020-10-13 16:11:07 -07003187
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded NumPy Generator.

    Args:
        seed: seed for ``np.random.default_rng``. When None,
            ``self.random_seed + 1`` is used.
    """
    # ``is None`` (not ``== None``): array-like seeds would make ``==``
    # elementwise and the ``if`` ambiguous.
    if seed is None:
        seed = self.random_seed + 1
    self.rng = np.random.default_rng(seed)
3192
def getRandTensor(self, shape, dtype):
    """Return a random NumPy array of ``shape`` with values suited to ``dtype``.

    BOOL gives a ``np.bool_`` array. Integer types give ``np.int32`` arrays
    (INT48 gives ``np.int64``) drawn over the type's range; INT4 is limited
    to [-7, 7]. FLOAT gives ``np.float32`` values from ``rng.random``.

    Raises:
        Exception: if ``dtype`` is not handled.
    """
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        return np.int32(self.rng.integers(low=-7, high=8, size=shape))
    elif dtype == DType.INT8:
        return np.int32(self.rng.integers(low=-128, high=128, size=shape))
    elif dtype == DType.UINT8:
        return np.int32(self.rng.integers(low=0, high=256, size=shape))
    elif dtype == DType.INT16:
        return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
    elif dtype == DType.INT32:
        return np.int32(
            self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
        )
    elif dtype == DType.INT48:
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003218
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per (shape, dtype) pair to ``self.ser``.

    ``shape_list`` and ``dtype_list`` must have the same length. Returns the
    objects returned by ``self.ser.addPlaceholder``, in order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, values))
    return placeholders
3229
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant per (shape, dtype) pair to ``self.ser``.

    ``shape_list`` and ``dtype_list`` must have the same length. Returns the
    objects returned by ``self.ser.addConst``, in order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, values))
    return consts
3240
def makeShape(self, rank):
    """Return a shape as an ``np.int32`` array.

    If ``self.targetted_shape`` is truthy it is returned (converted).
    Otherwise ``rank`` dimensions are drawn from
    ``self.args.tensor_shape_range`` (low inclusive, high exclusive).
    """
    if not self.targetted_shape:
        low = self.args.tensor_shape_range[0]
        high = self.args.tensor_shape_range[1]
        dims = self.rng.integers(low=low, high=high, size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003251
def setTargetShape(self, shape):
    """Pin the shape that makeShape returns.

    Storing a falsy value lets makeShape go back to random shapes.
    """
    self.targetted_shape = shape
3254
def randInt(self, low=0, high=256):
    """Draw one random integer via ``self.rng.integers`` and return it as ``np.int32``.

    With a numpy Generator as ``self.rng`` the value lies in ``[low, high)``.
    """
    draws = self.rng.integers(low=low, high=high, size=1)
    (value,) = np.int32(draws)
    return value
3257
def getRandNumberDType(self, dtype):
    """Return one random scalar appropriate for ``dtype``.

    Behaviour by dtype:
        FLOAT: ``self.rng.random()``.
        BOOL: a random choice of False/True.
        INT4, INT8, INT16, INT32: an ``np.int32`` drawn from a fixed
            ``[low, high)`` range.
        INT48: an ``np.int64``, because the 48-bit range does not fit in int32.

    Raises:
        Exception: for any other dtype.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # (dtype, (low, high)) sampling bounds, high exclusive
    int32_bounds = (
        (DType.INT4, (-7, 8)),  # TOSA specific INT4 weight range from -7 to 7
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, (-(1 << 31), (1 << 31))),
    )
    for candidate, (low, high) in int32_bounds:
        if dtype == candidate:
            return np.int32(self.rng.integers(low, high, size=1))[0]

    if dtype == DType.INT48:
        # Special size: returned as int64 to hold the 48-bit range
        low, high = (-(1 << 47), (1 << 47))
        return np.int64(self.rng.integers(low, high, size=1))[0]

    raise Exception("Unknown dtype: {}".format(dtype))
3280
def shapeStr(self, shape):
    """Format ``shape`` as an ``x``-separated string.

    For example, [2, 3, 4] -> "2x3x4". Each dimension is converted with
    ``str()``, so numpy integer arrays work too. An empty shape gives "".
    """
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003289
def typeStr(self, t):
    """Return the short name tag for a DType.

    A list (asserted to have at least two entries) is rendered from its first
    two entries as "<first>x<second>", e.g. [INT8, INT4] -> "i8xi4".

    Raises:
        Exception: if ``t`` is not one of the recognised DTypes.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))

    # Checked in order with ==, first match wins
    tags = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for dtype, tag in tags:
        if t == dtype:
            return tag
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003313
def typeWidth(self, t):
    """Return the bit width of data type ``t``.

    Covers the integer types (INT4/INT8/UINT8/INT16/INT32/INT48) as well as
    FLOAT (32 bits) and BOOL (1 bit).

    Raises:
        Exception: if ``t`` is not one of the recognised DTypes.
    """
    # Checked in order with ==, first match wins
    widths = (
        (DType.INT4, 4),
        (DType.INT8, 8),
        (DType.UINT8, 8),
        (DType.INT16, 16),
        (DType.INT32, 32),
        (DType.INT48, 48),
        (DType.FLOAT, 32),
        (DType.BOOL, 1),
    )
    for dtype, width in widths:
        if t == dtype:
            return width
    # Message previously claimed a string conversion (copy of typeStr's text)
    raise Exception("Unknown dtype, cannot determine width: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003334
# Operator build functions
# The argument generators return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
# used to generate the test name and the build_fcn_arg_list is expanded and
# passed to one of the build_* functions below.
3340
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Build a unary operator on tensor ``a``.

    Args:
        op: one of two forms:
            - a bare op code (int), as passed when building placeholders;
            - an operator descriptor dict with 'op' (op code) and
              'operands' (a two-element count pair).
        a: input tensor.
        validator_fcns: ERROR_IF validators forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ErrorIf case under test, or None for a valid test.
        qinfo: quantization info forwarded to the serializer.

    Returns:
        The result tensor created by OutputShaper.unaryOp.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Bare op codes and IDENTITY are serialized straight away, without
    # ERROR_IF validation; names are passed unwrapped, not as lists.
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
        return result_tens
    if op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
        return result_tens

    # Ensure the (possibly new) output type has matching qinfo
    if error_name == ErrorIf.WrongOutputType and result_tens.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, a.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    in_names = [a.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names, None, qinfo)
    return result_tens
3383
def build_binary_broadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a broadcasting binary operator on tensors ``a`` and ``b``.

    Args:
        op: operator descriptor dict with 'op' (op code) and 'operands'
            (a two-element count pair whose sum is passed as num_operands).
        a, b: input tensors; the output comes from
            OutputShaper.binaryBroadcastOp.
        validator_fcns: ERROR_IF validators forwarded to
            TosaErrorValidator.evValidateErrorIfs. It defaults to None, like
            the other build_* methods; previously it had no default.
        error_name: ErrorIf case under test, or None for a valid test.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
3412
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a non-broadcasting binary operator on ``a`` and ``b``.

    The output comes from OutputShaper.binaryNonBroadcastOp.

    NOTE(review): ``validator_fcns`` and ``error_name`` are accepted but not
    used, so no ERROR_IF validation runs here. Presumably they exist to match
    the other build_* signatures; confirm.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
3417
def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None):
    """Build an arithmetic-right-shift operator on ``a`` and ``b``.

    Args:
        op: operator descriptor dict with 'op' and 'operands'.
        a, b: input tensors; the output comes from
            OutputShaper.binaryBroadcastOp.
        round: value serialized via ArithmeticRightShiftAttribute
            (presumably the rounding flag).
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: ErrorIf case under test, or None.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], in_names, out_names, shift_attr)
    return result_tens
3448
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Build a MUL operator on ``a`` and ``b`` with the given ``shift``.

    The result dtype is set as follows:
        - Non-FLOAT inputs: forced to INT32.
        - WrongOutputType ERROR_IF case: a random choice of INT8/INT16/INT48,
          applied after the INT32 override.

    ``shift`` is serialized via MulAttribute.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply: integer inputs produce an INT32 result
    if a.dtype != DType.FLOAT:
        result_tens.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tens.setDtype(self.rng.choice(wrong_dtypes))

    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], in_names, out_names, mul_attr)
    return result_tens
3488
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Build a TABLE operator applying ``table`` to tensor ``a``.

    ``table`` is serialized via TableAttribute. The output comes from
    OutputShaper.tableOp.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    in_names = [a.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names, table_attr)
    return result_tens
3519
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Build a select operator over condition ``cond`` and values ``a``/``b``.

    The output comes from OutputShaper.selectOp. Inputs are serialized in
    the order cond, a, b.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    in_names = [cond.name, a.name, b.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
3549
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a binary comparison operator on ``a`` and ``b``.

    The output comes from OutputShaper.binaryComparisonOp.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
3579
def build_argmax(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build an ARGMAX operator over ``axis`` of tensor ``a``.

    Args:
        op: operator descriptor dict with 'op' and 'operands'.
        a: input tensor; the output comes from OutputShaper.argmaxOp.
        axis: axis serialized via AxisAttribute.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
            Defaults to None, like the sibling build_* methods.
        error_name: ErrorIf case under test, or None. Defaults to None.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3611
def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
    """Build a 2D pooling operator on ``input``.

    Args:
        op: operator descriptor dict with 'op' and 'operands'.
        input: input tensor; the output comes from OutputShaper.pool2dOp.
        stride, pad, kernel: serialized via PoolAttribute(kernel, stride, pad).
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: ErrorIf case under test, or None.
        qinfo: quantization info; may be rebuilt (see below).

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)

    # For WrongInputType with a non-INT8/UINT8 input, rebuild qinfo from the
    # actual input/output dtypes.
    if error_name == ErrorIf.WrongInputType and input.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, input.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    in_names = [input.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad)
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr, qinfo)
    return result_tens
3654
def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Build a CONV2D operator from ``ifm``, ``filter`` and ``bias``.

    ``padding`` is asserted to have four entries. The output comes from
    OutputShaper.conv2dOp. ``padding``, ``strides`` and ``dilations`` are
    serialized via ConvAttribute, and ``qinfo`` is forwarded to the
    serializer.

    Returns:
        The result tensor.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(self.ser, ifm, filter, strides, padding, dilations)

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name], conv_attr, qinfo)
    return result_tens
3668
def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Build a CONV3D operator from ``ifm``, ``filter`` and ``bias``.

    ``padding`` is asserted to have six entries. The output comes from
    OutputShaper.conv3dOp. ``padding``, ``strides`` and ``dilations`` are
    serialized via ConvAttribute, and ``qinfo`` is forwarded to the
    serializer.

    Returns:
        The result tensor.
    """
    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(self.ser, ifm, filter, strides, padding, dilations)

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name], conv_attr, qinfo)
    return result_tens
3682
def build_transpose_conv2d(
    self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
):
    """Build a TRANSPOSE_CONV2D operator from ``ifm``, ``filter`` and ``bias``.

    ``outpad`` is asserted to have two entries. The result tensor is created
    by OutputShaper.transposeConv2DOp from ``output_shape``. ``outpad``,
    ``stride``, ``dilation`` and ``output_shape`` are serialized via
    TransposeConvAttribute, and ``qinfo`` is forwarded to the serializer.

    Returns:
        The result tensor.
    """
    assert len(outpad) == 2
    result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name], tconv_attr, qinfo)
    return result_tens
3696
def build_depthwise_conv2d(
    self, op, ifm, filter, bias, strides, padding, dilations, qinfo
):
    """Build a DEPTHWISE_CONV2D operator from ``ifm``, ``filter`` and ``bias``.

    The output comes from OutputShaper.depthwiseConv2dOp. ``padding``,
    ``strides`` and ``dilations`` are serialized via ConvAttribute, and
    ``qinfo`` is forwarded to the serializer.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name], conv_attr, qinfo)
    return result_tens
3711
def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
    """Build a FULLY_CONNECTED operator from ``ifm``, ``filter`` and ``bias``.

    The output comes from OutputShaper.fullyConnectedOp. The filter dtype is
    passed to the ERROR_IF validators as ``weight_dtype``. ``qinfo`` is
    forwarded to the serializer.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)

    in_names = [ifm.name, filter.name, bias.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names, None, qinfo)
    return result_tens
3743
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Build a MATMUL operator on ``a`` and ``b``.

    The output comes from OutputShaper.matmulOp. The shapes and dtypes of
    both operands are passed to the ERROR_IF validators. ``qinfo`` is
    forwarded to the serializer.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names, None, qinfo)
    return result_tens
3774
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a reduction operator over ``axis`` of tensor ``a``.

    Args:
        op: operator descriptor dict with 'op' and 'operands'.
        a: input tensor; the output comes from OutputShaper.reduceOp.
        axis: axis serialized via AxisAttribute.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
            Defaults to None, like the sibling build_* methods.
        error_name: ErrorIf case under test, or None.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3806
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator on tensor ``a`` with random bounds.

    Two random values of ``a.dtype`` are drawn; normally the smaller becomes
    the minimum and the larger the maximum. For the MaxSmallerMin ERROR_IF
    case the values are redrawn until distinct and then swapped, so that
    max < min.

    FLOAT inputs put the bounds in the last two ClampAttribute slots; other
    dtypes use the first two.

    Returns:
        The result tensor created by OutputShaper.unaryOp.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Equal values could not produce max < min, so redraw until distinct
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
        min_val, max_val = max(bounds), min(bounds)

    in_names = [a.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype != DType.FLOAT:
        clamp_attr.ClampAttribute(min_val, max_val, 0, 0)
    else:
        clamp_attr.ClampAttribute(0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return result_tens
3854
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on tensor ``a``.

    The LeakyReluAttribute value is a random FLOAT from getRandNumberDType.

    NOTE(review): ``validator_fcns`` is not used, so no ERROR_IF validation
    runs here. ``error_name`` is only forwarded to OutputShaper.unaryOp.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FLOAT)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [result_tens.name], relu_attr)
    return result_tens
3863
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator on tensor ``a``.

    NOTE(review): per the note above, this op presumably needs an additional
    type/input, but only ``a`` is wired up. ``validator_fcns`` is unused;
    ``error_name`` is only forwarded to OutputShaper.unaryOp.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
3870
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator on tensor ``a``.

    The output comes from OutputShaper.unaryOp. ERROR_IF validators run
    before the operator is serialized.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    in_names = [a.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
3898
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator on tensor ``a``.

    The output comes from OutputShaper.unaryOp. ERROR_IF validators run
    before the operator is serialized.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    in_names = [a.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
3926
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a CONCAT operator over a variable number of tensors.

    Args:
        op: operator descriptor dict with 'op' and 'operands'.
        *a: the input tensors followed by the concat axis as the final
            element. The axis is asserted to be an int, except in the
            WrongInputType ERROR_IF case.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: ErrorIf case under test, or None.

    Returns:
        The result tensor created by OutputShaper.concatOp.
    """
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(a[-1], int)

    # To store a variable length list of input tensors the axis is passed
    # along with it as the last element; split the two apart.
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(self.ser, self.rng, axis, *a, error_name=error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003971
def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
    """Build a PAD operator on tensor ``a``.

    Args:
        op: operator descriptor dict with 'op' and 'operands'.
        a: input tensor; the output comes from OutputShaper.padOp.
        padding: padding values; flattened (``padding.flatten()``) into the
            PadAttribute, so presumably a numpy array.
        pad_const_int, pad_const_float: pad constants serialized via
            PadAttribute.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: ErrorIf case under test, or None.
        qinfo: quantization info forwarded to the validators and serializer.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    in_names = [a.name]
    out_names = [result_tens.name]
    p_count, c_count = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr, qinfo)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004006
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Emit a RESHAPE of ``a`` to ``newShape`` and return the output tensor.

    The ERROR_IF validators run before the operator is serialized.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op['op'], in_names, out_names, reshape_attr)
    return out_tensor
4037
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Emit a REVERSE of ``a`` along ``axis`` and return the output tensor.

    The output tensor is shaped by ``OutputShaper.unaryOp``; ``axis`` is
    passed to the validators and stored in an AxisAttribute.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], in_names, out_names, axis_attr)
    return out_tensor
4069
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Emit a TRANSPOSE of ``a`` using permutation ``perms``.

    Returns the output tensor created by ``OutputShaper.transposeOp``.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    self.ser.addOperator(op['op'], in_names, out_names, transpose_attr)
    return out_tensor
4102
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Emit a SLICE of ``a`` described by ``start`` and ``size``.

    Returns the output tensor created by ``OutputShaper.sliceOp``.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op['op'], in_names, out_names, slice_attr)
    return out_tensor
4135
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Emit a TILE of ``a`` by ``multiples`` and return the output tensor."""
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op['op'], in_names, out_names, tile_attr)
    return out_tensor
4166
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Emit a GATHER from ``values`` using freshly generated indices.

    An INT32 constant index tensor of shape ``[values.shape[0], W]`` is
    created. ``W`` is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``, and every index lies in
    ``[0, values.shape[1])``, so it never exceeds the values tensor.

    Returns the output tensor created by ``OutputShaper.gatherOp``.
    """
    k_len = values.shape[1]
    shape_lo, shape_hi = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    w_len = self.randInt(shape_lo, shape_hi)
    index_data = np.int32(
        self.rng.integers(low=0, high=k_len, size=[values.shape[0], w_len])
    )
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, index_tens, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, index_tens.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    self.ser.addOperator(op['op'], in_names, out_names)
    return out_tensor
4208
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Emit a SCATTER of ``input`` into ``values_in`` at generated indices.

    An INT32 constant index tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is created, with every index in
    ``[0, values_in.shape[1])``.

    Returns the output tensor created by ``OutputShaper.scatterOp``.
    """
    k_len = values_in.shape[1]
    w_len = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_len, size=[values_in.shape[0], w_len])
    )
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(self.ser, self.rng, values_in, index_tens, input, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_in.name, index_tens.name, input.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        input_dtype=input.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    self.ser.addOperator(op['op'], in_names, out_names)
    return out_tensor
4248
Matthew Haddon848efb42021-09-09 12:30:53 +01004249
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name = None,
):
    """Emit a RESIZE of ``input`` and return the output tensor.

    All resize parameters (integer and floating-point stride/offset, shift,
    mode, output dims and dtypes) go to ``OutputShaper.resizeOp``, to the
    ERROR_IF validators, and to the ResizeAttribute.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser, self.rng, input, mode, stride, offset, shift,
        stride_fp, offset_fp, output_dims, input_dtype, output_dtype, error_name,
    )

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=in_names,
        output_list=out_names,
        result_tensor=out_tensor,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(output_dims, stride, offset, shift, stride_fp, offset_fp, mode)

    self.ser.addOperator(op['op'], in_names, out_names, resize_attr)
    return out_tensor
4318
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Emit a two-input / two-output identity operator for ``val`` and ``val2``.

    Each output is shaped by ``OutputShaper.unaryOp`` from its input. Only
    the first output tensor is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator-description dict (as in every other builder);
    # the serializer needs the opcode stored under 'op', not the dict itself.
    self.ser.addOperator(
        op['op'], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
4326
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark the already-created tensor ``val`` as a graph output and return it.

    No operator is emitted; ``op``, ``validator_fcns`` and ``error_name`` are
    accepted only for signature compatibility with the other builders.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07004330
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Emit a CAST of ``val`` to ``out_dtype`` and return the output tensor.

    The output tensor is created by ``OutputShaper.typeConversionOp``; no
    attribute is attached to the operator.
    """
    out_tensor = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)

    p_count, c_count = op["operands"]
    # Give the ERROR_IF helper the chance to invalidate the name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    TosaErrorValidator.evValidateErrorIfs(self.ser, validator_fcns, error_name, **checks)

    self.ser.addOperator(op['op'], in_names, out_names)
    return out_tensor
4359
def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns=None, error_name=None):
    """Emit a RESCALE from ``val`` to ``out_dtype`` with randomised parameters.

    Zero points are random for INT8/UINT8 sides; otherwise they are zero,
    unless ``error_name`` requests a non-zero zero point for that side.
    One scale per channel (``val.shape[-1]`` when ``per_channel``, else 1)
    is drawn, clipped to the scale32/scale16 range, and converted to a
    multiplier/shift pair by ``TosaQuantGen.computeMultiplierAndShift``.

    Returns the output tensor created by ``OutputShaper.typeConversionOp``.
    """
    result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    def _pick_zero_point(dtype, force_nonzero):
        """Return (zero_point, extra_width_bits) for one side of the rescale."""
        if dtype == DType.INT8:
            return self.randInt(-128, 128), 1
        if dtype == DType.UINT8:
            return self.randInt(0, 256), 1
        if force_nonzero:
            # ERROR_IF test: a zero point is required to be non-zero here.
            zp = self.randInt(-128, 128)
            if zp == 0:
                zp = zp + self.rng.integers(1, 10)
            return zp, 1
        return 0, 0

    # Input side is resolved before the output side to keep the RNG call order.
    input_zp, in_extra = _pick_zero_point(val.dtype, error_name == ErrorIf.InputZeroPointNotZero)
    in_type_width = in_type_width + in_extra
    output_zp, out_extra = _pick_zero_point(out_dtype, error_name == ErrorIf.OutputZeroPointNotZero)
    out_type_width = out_type_width + out_extra

    # Calculate scale based on: scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

    qinfo = (input_zp, output_zp)
    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op['op'], input_list, output_list, attr)
    return result_tens
4463
def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
    """Build a COND_IF whose THEN/ELSE blocks each output a random INT32 constant.

    ``then_tens``/``else_tens`` are ignored except for ``then_tens.shape``,
    which sets the output shape. ``cond`` becomes a BOOL constant. For the
    CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the matching
    block outputs a constant with a perturbed shape instead.

    Returns the INT32 output tensor of the COND_IF operator.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    # Make then/else tensors
    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
        incorrect_shape = deepcopy(then_tens.shape)
        for i in range(len(incorrect_shape)):
            delta = self.rng.choice([-3, -2, 2, 3])
            # A naive "dim + delta" goes to zero or below for small dims
            # (e.g. 1 - 3), and rng.integers() rejects negative sizes. Move the
            # other way instead so the dim stays positive yet still differs.
            if incorrect_shape[i] + delta < 1:
                delta = abs(delta)
            incorrect_shape[i] = incorrect_shape[i] + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

    self.ser.startBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.startBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_tens)

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks
    )

    return result_tens
4521
def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    FLOAT/INT32 inputs use ADD (then) / SUB (else). INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (then) / LOGICAL_LEFT_SHIFT (else). Any other dtype
    trips an assertion. For the CondIf{Input,Output}List{Then,Else}GraphMismatch
    ERROR_IF tests, the selected block gets an input or output with a
    perturbed shape.

    Returns the output tensor of the COND_IF operator.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
                      ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # NOTE: the loop variable must not be named ``op``: that would clobber the
    # operator-description parameter that is handed to the validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
                (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
                (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks
    )

    return result_tens
4587
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that repeatedly adds ``a`` into a zeroed accumulator.

    The loop carries (iteration counter, ``a``, accumulator). The COND block
    tests ``iter > 0``. The BODY block computes ``acc + a`` and ``iter - 1``.
    Several ERROR_IF variants replace tensors with shape-perturbed copies,
    or give the COND output a non-BOOL dtype.

    Returns the accumulator output tensor of the WHILE_LOOP operator.
    """
    ser = self.ser
    deltas = [-3, -2, 2, 3]

    iter_tens = ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator placeholder initialised with zeros of a's shape.
    acc = ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop.
    iter_out = ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        bad_acc = deepcopy(acc)
        for dim in range(len(bad_acc.shape)):
            bad_acc.shape[dim] += self.rng.choice(deltas)
        acc_out_shape = bad_acc.shape
    else:
        acc_out_shape = acc.shape
    acc_out = ser.addIntermediate(acc_out_shape, acc.dtype)

    ser.addOperator(
        op['op'],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    ser.addOutputTensor(acc_out)

    block_list_errors = [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]
    if error_name in block_list_errors:
        bad_iter = deepcopy(iter_tens)
        for dim in range(len(bad_iter.shape)):
            bad_iter.shape[dim] += self.rng.choice(deltas)
        if len(bad_iter.shape) == 0:
            # A rank-0 counter has nothing to perturb, so add a dimension.
            bad_iter.shape.append(self.rng.choice(deltas))

        bad_acc = deepcopy(acc)
        for dim in range(len(bad_acc.shape)):
            bad_acc.shape[dim] += self.rng.choice(deltas)

    # COND block: inputs (iter, a, acc) -> output cond_tens = iter > 0
    ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_iter, a, bad_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        ser.addInputTensor(tens)
    zero_tens = ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_dtype = self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
    else:
        cond_dtype = DType.BOOL
    cond_tens = ser.addOutput([], cond_dtype)

    ser.addOperator(Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name])

    # BODY block: inputs (iter, a, acc) -> outputs (iter - 1, a, acc + a).
    # Local intermediate tensors are declared here for the outputs.
    ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_iter, a, bad_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        ser.addInputTensor(tens)

    one_tens = ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = bad_iter, bad_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = ser.addIntermediate(acc_src.shape, acc_src.dtype)

    ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    ser.addOperator(Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name])
    for tens in (iter_body_out, a, acc_body_out):
        ser.addOutputTensor(tens)

    TosaErrorValidator.evValidateErrorIfs(
        ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=ser.basicBlocks,
    )

    return acc_out
4688
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
    """Work out which shapes, ranks and dtypes to generate tests for.

    The operator limits (``op["rank"]`` and ``op["types"]``) are combined with
    the requested filters. For negative (ERROR_IF) tests the validator's
    ``param_reqs`` may override any of the three filters.

    Args:
        op: TOSA_OP_LIST entry; ``op["rank"]`` is an inclusive (min, max)
            pair and ``op["types"]`` the list of dtypes (an entry may itself
            be a list of dtypes).
        shapeFilter: requested shapes; a falsy value means "no restriction".
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: ``'positive'`` or ``'negative'``.
        validator: ERROR_IF validator; only used for negative tests.

    Returns:
        dict with ``'shapeFilter'``, ``'rankFilter'`` and ``'dtypeFilter'``
        keys, or None when testType is ``'negative'`` without a validator or
        testType is not recognised.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        # Explicit shapes were requested: allow every rank the operator supports
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Per-operand dtype lists are matched on their first dtype
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == 'positive':
        return {
            'shapeFilter': shapeFilter,
            'rankFilter': cleanRankFilter,
            'dtypeFilter': cleanDtypeFilter,
        }

    if testType != 'negative' or validator is None:
        return None

    error_arguments = validator(check=False, op=op)['param_reqs']

    # Set parameters as required by the ERROR_IF, otherwise fall back to the
    # positive-test filters
    rank_req = error_arguments['rank']
    dtype_req = error_arguments['dtype']
    shape_req = error_arguments['shape']
    return {
        # Reduce number of shapes to keep test numbers small
        'shapeFilter': shape_req if shape_req is not None else shapeFilter[:2],
        'rankFilter': rank_req if rank_req is not None else cleanRankFilter,
        'dtypeFilter': dtype_req if dtype_req is not None else cleanDtypeFilter,
    }
4759
4760
def genOpTestList(
    self, opName, shapeFilter=None, rankFilter=None, dtypeFilter=None, testType='positive'
):
    """Generate the list of tests to build for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: shapes to restrict generation to; None (or ``[None]``)
            means no restriction.
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: ``'positive'`` for normal tests, or ``'negative'`` for
            ERROR_IF tests (one set per entry of ``op["error_if_validators"]``).

    Returns:
        list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. Empty when create_filter_lists yields no filters.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    # None instead of a shared mutable ``[None]`` default; same meaning.
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Initialize a new random number generator from the fixed seed
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, agen_fcn = op["build_fcn"]

    testList = []
    if testType == 'negative' and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)['error_name']
        else:
            error_name = None

        filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
        if filterDict is None:
            return []
        cleanRankFilter = filterDict['rankFilter']
        cleanDtypeFilter = filterDict['dtypeFilter']
        cleanShapeFilter = filterDict['shapeFilter']

        # filterDict is only non-None for 'positive' or 'negative' tests
        if testType == 'positive':
            testPrefix = opName
        else:
            testPrefix = "{}_ERRORIF_{}".format(opName, error_name)

        for r in cleanRankFilter:
            if opName.startswith("conv3d"):
                assert r == 5, "conv3d test must have input rank == 5"
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = "{}_{}_{}".format(testPrefix, shapeStr, typeStr)
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)
                        testList.append((opName, testStr, t, error_name, shapeList, args))

    if testType == 'positive' and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. A test is removed if ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5])
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
4851
Matthew Haddone86fd342021-09-07 16:12:21 +01004852
def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
    """Build one test through the operator's build function and serialize it.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: test name passed to createSerializer.
        dtype_or_dtypeList: one dtype for all operands, or a list with one
            dtype per operand.
        error_name: ERROR_IF being tested, or None for a positive test.
        shapeList: shape of each operand.
        testArgs: extra arguments for the build function.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        TypeError: re-raised from the build function after printing the
            call details.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        # Copy so tensor generation cannot modify the caller's list
        dtypeList = list(dtype_or_dtypeList)
    elif op["op"] == Op.CONCAT:
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    # CONCAT takes a variable number of inputs, so only check other ops
    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Build the random tensor operands and the test
    tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    try:
        if error_if_validators is None:
            # Build functions without ERROR_IF support take qinfo positionally
            if qinfo is not None:
                resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
            else:
                resultName = build_fcn(self, op, *tens, *testArgs)
        else:
            build_kwargs = {"validator_fcns": error_if_validators, "error_name": error_name}
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo
            resultName = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
    except TypeError:
        print(
            "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                build_fcn, tens, testArgs
            )
        )
        raise

    if resultName is None:
        print("Invalid ERROR_IF tests created")

    # Save the serialized test
    self.serialize("test")
4929
4930
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
    """Create the input tensors (placeholders and constants) for one test.

    By default the first ``pCount`` operands become placeholders and the rest
    constants, filled by buildPlaceholderTensors / buildConstTensors. Some
    operators get constrained data instead:

    * ADD / SUB (INT32, positive tests): operand 1 is adjusted so the result
      does not saturate INT32.
    * COND_IF / WHILE_LOOP (INT32): data is drawn with the INT16 generator to
      stop saturation of the add/sub ops.
    * ARITHMETIC_RIGHT_SHIFT: the shift operand is drawn from [0, num_bits).
    * SELECT: the condition operand is BOOL.
    * INTDIV (positive tests): no zero divisor and no -(2**31) / -1 case.
    * MUL (integer, positive tests): operands are halved until the shifted
      product lies in (-(2**31), 2**31 - 1].
    * CONCAT: ``self.args.num_const_inputs_concat`` trailing inputs become
      constants.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        dtypeList: dtype of each operand. Not modified.
        shapeList: shape of each operand.
        testArgs: build-function arguments. For CONCAT, ``testArgs[0]`` (the
            axis) is converted to int IN PLACE, and the caller relies on it.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        list of serializer tensors in operand order.
    """
    pCount, cCount = op["operands"]
    opcode = op["op"]

    tens = []
    if opcode in (Op.ADD, Op.SUB) and dtypeList[0] == DType.INT32 and error_name is None:
        # Make sure the operation does not cause value saturation - where
        # the number wraps due to limited number of bits to store the answer
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
        add = opcode == Op.ADD
        a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
        b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
        if add:
            res_arr = np.add(a_arr, b_arr, dtype=np.int64)
        else:
            res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

        # How far each result lies above INT32 max (>= 0) / below INT32 min (<= 0)
        max_i32 = (1 << 31) - 1
        min_i32 = -(1 << 31)
        sat_max_arr = np.maximum(res_arr - max_i32, 0)
        sat_min_arr = np.minimum(res_arr - min_i32, 0)

        if not add:
            # Swap saturation values and negate values as we need to perform opposite operations
            sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

        def _reduce_to_b_shape(arr, reducer):
            # arr may have been broadcast to the result shape; collapse each
            # broadcast axis back to size 1 so it matches b's shape again.
            for axis, dim in enumerate(b_arr.shape):
                if dim != arr.shape[axis]:
                    assert (
                        dim == 1
                    ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    arr = reducer(arr, axis=axis, keepdims=True)
            return arr

        # Create new array of unsaturated values by clipping values as needed
        b_unsat_arr = b_arr
        if (sat_max_arr != 0).any():
            b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
            b_unsat_arr = _reduce_to_b_shape(b_unsat_arr, np.amin)
        if (sat_min_arr != 0).any():
            b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
            b_unsat_arr = _reduce_to_b_shape(b_unsat_arr, np.amax)

        tens.append(self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr))
        tens.append(self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr))

    elif opcode in (Op.COND_IF, Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
        # Limit input tensors with cond_if_binary or while_loop to stop
        # saturation of add/sub ops
        for idx, shape in enumerate(shapeList):
            arr = self.getRandTensor(shape, DType.INT16)
            if idx < pCount:
                tens.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
            else:
                tens.append(self.ser.addConst(shape, dtypeList[idx], arr))

    elif opcode == Op.ARITHMETIC_RIGHT_SHIFT:
        # Force value of operand[1] (the shift) to be within [0, num_bits)
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        for idx, shape in enumerate(shapeList):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    shift_high = 8
                elif dtypeList[idx] == DType.INT16:
                    shift_high = 16
                elif dtypeList[idx] == DType.INT32:
                    shift_high = 32
                elif error_name == ErrorIf.WrongInputType:
                    shift_high = 8
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
                arr = np.int32(self.rng.integers(low=0, high=shift_high, size=shape))
            else:
                arr = self.getRandTensor(shape, dtypeList[idx])
            tens.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

    elif opcode == Op.SELECT:
        # Set datatype of condition tensor to boolean. Build a new list rather
        # than writing into the caller's dtypeList.
        dtypeList = [DType.BOOL] + list(dtypeList[1:])
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

    elif opcode == Op.INTDIV and error_name is None:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.INTDIV must have 2 placeholders, 0 consts"

        # Two invalid cases for Op.INTDIV:
        # 1. divisor == 0
        # 2. dividend == -(1<<31) and divisor == -1
        # Redraw until neither can occur. The second check is conservative:
        # it rejects the draw if both values appear anywhere in the tensors.
        while True:
            dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

            if (divisor_arr == 0).any():
                continue

            if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                continue

            break

        tens.append(self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr))
        tens.append(self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr))

    elif opcode == Op.MUL and error_name is None:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.MUL must have 2 placeholders, 0 consts"

        if dtypeList[0] == DType.FLOAT:
            tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
        else:
            # Make sure multiply result in int32 range
            shift = testArgs[0]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            else:
                raise Exception("OpMul: invalid input dtype")

            low = -(2 ** (num_bits - 1))
            high = (2 ** (num_bits - 1)) - 1

            # One (a, b) pair is drawn per operand and only the last pair is
            # kept. The redundant draws are retained deliberately so that an
            # existing random seed keeps producing the same test data.
            for _ in shapeList:
                a_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[1])
                )

            # Halve both operands until the (rounded, shifted) product fits
            while True:
                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2 ** 31)).all() and (
                    result_arr <= ((2 ** 31) - 1)
                ).all():
                    break

                a_arr = a_arr // 2
                b_arr = b_arr // 2

            tens.append(self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr))
            tens.append(self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr))

    elif opcode == Op.CONCAT:
        num_const = self.args.num_const_inputs_concat
        if num_const == 0:
            count = len(shapeList)
        else:
            # Always keep at least one placeholder input
            count = max(1, len(shapeList) - num_const)

        # Ensure axis is an int (in place: the caller passes testArgs on to
        # the build function)
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name)

        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))

    else:
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

    return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07005153
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` entries of TOSA_OP_LIST into concrete ops.

    Each template is shallow-copied once per kernel size, with "filter" set
    to the kernel and "template" set to False. The new entry is stored under
    a name such as ``conv2d_3x1`` or ``conv3d_1x2x1``. Afterwards every entry
    still marked as a template is deleted.

    NOTE(review): a second call raises KeyError because the templates have
    already been deleted. If TOSA_OP_LIST is a class attribute, the expansion
    is also shared by every instance.
    """
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    TEMPLATES_2D = ("conv2d", "depthwise_conv2d", "transpose_conv2d")
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def _instantiate(template_name, test_name, kernel):
        # Copy the template entry and specialise it for one kernel size
        entry = self.TOSA_OP_LIST[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        self.TOSA_OP_LIST[test_name] = entry

    # Insertion order matches kernel-major, then conv2d/depthwise/transpose
    for k in KERNELS_2D:
        for prefix in TEMPLATES_2D:
            _instantiate(
                "{}_TEMPLATE".format(prefix), "{}_{}x{}".format(prefix, k[0], k[1]), k
            )

    for k in KERNELS_3D:
        _instantiate("conv3d_TEMPLATE", "conv3d_{}x{}x{}".format(k[0], k[1], k[2]), k)

    # Delete any templates after having created any dynamic ops.
    # This is a two-pass operation because it's bad practice to delete
    # keys from dictionaries while iterating.
    template_keys = [
        name for name, entry in self.TOSA_OP_LIST.items() if entry.get("template")
    ]
    for name in template_keys:
        del self.TOSA_OP_LIST[name]
5200
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Each TOSA_OP_LIST entry must have an "operands" 2-tuple, a "build_fcn"
    3-tuple, a "types" list and an "op" field. A missing "rank" is set to
    DEFAULT_RANK_RANGE (the entry is updated in place).

    Raises:
        Exception: naming the first op with a missing or malformed required
            field; the underlying lookup/unpack error is chained as the cause.
    """
    for op, entry in self.TOSA_OP_LIST.items():

        # Required fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _fcn, _tgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
# Tensor operator list (TOSA_OP_LIST), keyed by test name. Each entry is a
# dict with these fields:
# 'op': the Op value of the operator under test (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TosaTensorGen shape-generator
#     function, TosaArgGen argument-generator function or None)
# 'types': list of datatypes to be tested; an entry may itself be a list
#     giving one dtype per operand (see TYPE_CONV)
# Optional fields:
# 'qgen': TosaQuantGen function that produces the quantization info
# 'error_if_validators': TosaErrorValidator functions; one set of negative
#     (ERROR_IF) tests is generated per validator
# 'invalid_test_validators': TosaInvalidValidator functions; positive tests
#     that any of them flag are dropped
# 'template' / 'filter': template entries are expanded per kernel size by
#     createDynamicOpLists, which stores the kernel in 'filter'
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Convolution-style types: each inner list gives one dtype per operand
# (presumably input, weight, bias/accumulator -- confirm against the build
# functions); the plain DType.FLOAT entry is used for every operand.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# Rank range given to ops that do not specify 'rank' (see initOpListDefaults)
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07005270
5271 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08005272 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08005273 "argmax": {
5274 "op": Op.ARGMAX,
5275 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005276 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005277 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5278 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005279 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
5280 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5281 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005283 "avg_pool2d": {
5284 "op": Op.AVG_POOL2D,
5285 "operands": (1, 0),
5286 "rank": (4, 4),
5287 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5288 "qgen": TosaQuantGen.qgUnary,
5289 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005290 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
5291 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5292 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5293 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5294 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005295 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005296 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005297 "conv2d_TEMPLATE": {
5298 "op": Op.CONV2D,
5299 "operands": (1, 2),
5300 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01005301 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005302 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005303 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01005304 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005305 "template": True,
5306 },
Kevin Cheng1533b852021-09-01 12:51:58 -07005307 # Templated operator. Filled in by createDynamicOpLists
5308 "conv3d_TEMPLATE": {
5309 "op": Op.CONV3D,
5310 "operands": (1, 2),
5311 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01005312 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07005313 "qgen": TosaQuantGen.qgConv,
5314 "types": TYPE_CONV,
5315 "template": True,
5316 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005317 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005318 "depthwise_conv2d_TEMPLATE": {
5319 "op": Op.DEPTHWISE_CONV2D,
5320 "operands": (1, 2),
5321 "filter": [1, 1],
5322 "rank": (4, 4),
5323 "build_fcn": (
5324 build_depthwise_conv2d,
5325 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01005326 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005327 ),
5328 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005329 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01005330 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005331 "template": True,
5332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005333 "fully_connected": {
5334 "op": Op.FULLY_CONNECTED,
5335 "operands": (1, 2),
5336 "rank": (2, 2),
5337 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
5338 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005339 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005340 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
5341 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005343 "matmul": {
5344 "op": Op.MATMUL,
5345 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07005346 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08005347 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
5348 "qgen": TosaQuantGen.qgMatmul,
5349 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005350 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
5351 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005353 "max_pool2d": {
5354 "op": Op.MAX_POOL2D,
5355 "operands": (1, 0),
5356 "rank": (4, 4),
5357 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
5358 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005359 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
5360 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
5361 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5362 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005363 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005364 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08005365 "transpose_conv2d_TEMPLATE": {
5366 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07005367 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005368 "rank": (4, 4),
5369 "build_fcn": (
5370 build_transpose_conv2d,
5371 TosaTensorGen.tgTransposeConv2D,
5372 TosaArgGen.agTransposeConv2D,
5373 ),
5374 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07005375 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01005376 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005377 "template": True,
5378 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005379 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08005380 "clamp": {
5381 "op": Op.CLAMP,
5382 "operands": (1, 0),
5383 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
5384 "types": TYPE_NARROW_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005385 "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5386 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005387 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08005388 "sigmoid": {
5389 "op": Op.SIGMOID,
5390 "operands": (1, 0),
5391 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
5392 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005393 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5394 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005395 },
5396 "tanh": {
5397 "op": Op.TANH,
5398 "operands": (1, 0),
5399 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
5400 "types": TYPE_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005401 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5402 TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005404 # Elementwise Binary Operators
5405 "add": {
5406 "op": Op.ADD,
5407 "operands": (2, 0),
5408 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5409 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005410 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005411 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005413 "arithmetic_right_shift": {
5414 "op": Op.ARITHMETIC_RIGHT_SHIFT,
5415 "operands": (2, 0),
5416 "build_fcn": (
5417 build_arithmetic_right_shift,
5418 TosaTensorGen.tgBroadcastFuzz,
5419 TosaArgGen.agArithmeticRightShift,
5420 ),
5421 "types": TYPE_INT,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005422 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5423 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005425 "bitwise_and": {
5426 "op": Op.BITWISE_AND,
5427 "operands": (2, 0),
5428 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5429 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005430 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005431 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005433 "bitwise_or": {
5434 "op": Op.BITWISE_OR,
5435 "operands": (2, 0),
5436 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5437 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005438 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005439 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005440 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005441 "bitwise_xor": {
5442 "op": Op.BITWISE_XOR,
5443 "operands": (2, 0),
5444 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5445 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005446 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005447 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005448 },
Matthew Haddon459443c2021-08-23 16:43:13 +01005449 "intdiv": {
5450 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005451 "operands": (2, 0),
5452 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5453 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005454 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005455 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005457 "logical_and": {
5458 "op": Op.LOGICAL_AND,
5459 "operands": (2, 0),
5460 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5461 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005462 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005463 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005464 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005465 "logical_left_shift": {
5466 "op": Op.LOGICAL_LEFT_SHIFT,
5467 "operands": (2, 0),
5468 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5469 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005470 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005471 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005473 "logical_right_shift": {
5474 "op": Op.LOGICAL_RIGHT_SHIFT,
5475 "operands": (2, 0),
5476 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5477 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005478 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005479 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005480 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005481 "logical_or": {
5482 "op": Op.LOGICAL_OR,
5483 "operands": (2, 0),
5484 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5485 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005486 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005487 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005488 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005489 "logical_xor": {
5490 "op": Op.LOGICAL_XOR,
5491 "operands": (2, 0),
5492 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5493 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005494 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005495 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005496 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005497 "maximum": {
5498 "op": Op.MAXIMUM,
5499 "operands": (2, 0),
5500 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5501 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005502 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005503 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005505 "minimum": {
5506 "op": Op.MINIMUM,
5507 "operands": (2, 0),
5508 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5509 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005510 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005511 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005512 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005513 "mul": {
5514 "op": Op.MUL,
5515 "operands": (2, 0),
5516 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
5517 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005518 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005519 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005520 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005521 "pow": {
5522 "op": Op.POW,
5523 "operands": (2, 0),
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005524 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08005525 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005526 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005527 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005529 "sub": {
5530 "op": Op.SUB,
5531 "operands": (2, 0),
5532 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
5533 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005534 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005535 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005536 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005537 "table": {
5538 "op": Op.TABLE,
5539 # Use the automatic generation functions to create the input array
5540 # but create the table tensor in the build function, as it may be
5541 # a different type from the input
5542 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00005543 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005544 "types": [DType.INT8, DType.INT16],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005545 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5546 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005548 # Elementwise Unary operators
5549 "abs": {
5550 "op": Op.ABS,
5551 "operands": (1, 0),
5552 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5553 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005554 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5555 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005556 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005557 "bitwise_not": {
5558 "op": Op.BITWISE_NOT,
5559 "operands": (1, 0),
5560 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5561 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005562 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5563 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005564 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005565 "ceil": {
5566 "op": Op.CEIL,
5567 "operands": (1, 0),
5568 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5569 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005570 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5571 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005572 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005573 "clz": {
5574 "op": Op.CLZ,
5575 "operands": (1, 0),
5576 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5577 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005578 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5579 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005581 "exp": {
5582 "op": Op.EXP,
5583 "operands": (1, 0),
5584 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5585 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005586 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5587 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005588 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005589 "floor": {
5590 "op": Op.FLOOR,
5591 "operands": (1, 0),
5592 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5593 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005594 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5595 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005596 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005597 "log": {
5598 "op": Op.LOG,
5599 "operands": (1, 0),
5600 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5601 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005602 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5603 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005605 "logical_not": {
5606 "op": Op.LOGICAL_NOT,
5607 "operands": (1, 0),
5608 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5609 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005610 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5611 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005612 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005613 "negate": {
5614 "op": Op.NEGATE,
5615 "operands": (1, 0),
5616 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5617 "qgen": TosaQuantGen.qgUnary,
5618 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005619 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
5620 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
5621 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005622 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005623 "reciprocal": {
5624 "op": Op.RECIPROCAL,
5625 "operands": (1, 0),
5626 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5627 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005628 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5629 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005630 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005631 "rsqrt": {
5632 "op": Op.RSQRT,
5633 "operands": (1, 0),
5634 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5635 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005636 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5637 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005639 # Elementwise Ternary operators
5640 "select": {
5641 "op": Op.SELECT,
5642 "operands": (3, 0),
5643 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
5644 "types": TYPE_FIB,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005645 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5646 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005647 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005648 # Comparison operators
5649 "equal": {
5650 "op": Op.EQUAL,
5651 "operands": (2, 0),
5652 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5653 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005654 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5655 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005656 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005657 "greater_equal": {
5658 "op": Op.GREATER_EQUAL,
5659 "operands": (2, 0),
5660 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5661 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005662 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5663 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005664 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005665 "greater": {
5666 "op": Op.GREATER,
5667 "operands": (2, 0),
5668 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
5669 "types": TYPE_FI32,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005670 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5671 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08005672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005673 # Reduction operators
5674 "reduce_all": {
5675 "op": Op.REDUCE_ALL,
5676 "operands": (1, 0),
5677 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5678 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005679 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5680 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5681 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005682 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005683 "reduce_any": {
5684 "op": Op.REDUCE_ANY,
5685 "operands": (1, 0),
5686 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5687 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005688 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5689 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5690 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005691 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005692 "reduce_max": {
5693 "op": Op.REDUCE_MAX,
5694 "operands": (1, 0),
5695 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5696 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005697 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5698 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5699 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005701 "reduce_min": {
5702 "op": Op.REDUCE_MAX,
5703 "operands": (1, 0),
5704 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5705 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005706 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5707 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5708 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005709 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005710 "reduce_product": {
5711 "op": Op.REDUCE_PRODUCT,
5712 "operands": (1, 0),
5713 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5714 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005715 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5716 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5717 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005718 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005719 "reduce_sum": {
5720 "op": Op.REDUCE_SUM,
5721 "operands": (1, 0),
5722 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5723 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01005724 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
5725 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5726 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08005727 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005728 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08005729 "concat": {
5730 "op": Op.CONCAT,
5731 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01005732 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005733 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005734 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
Matthew Haddon01c359d2021-10-15 16:30:48 +01005735 TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
5736 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005737 },
5738 "pad": {
5739 "op": Op.PAD,
5740 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01005741 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005742 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
5743 "qgen": TosaQuantGen.qgPad,
5744 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01005745 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
5746 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005747 },
5748 "reshape": {
5749 "op": Op.RESHAPE,
5750 "operands": (1, 0),
5751 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
5752 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01005753 "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5754 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005755 },
5756 "reverse": {
5757 "op": Op.REVERSE,
5758 "operands": (1, 0),
5759 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
5760 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005761 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType,
5762 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005763 },
5764 "slice": {
5765 "op": Op.SLICE,
5766 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01005767 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005768 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
5769 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01005770 "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
5771 TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
5772 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005773 },
5774 "tile": {
5775 "op": Op.TILE,
5776 "operands": (1, 0),
5777 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
5778 "types": TYPE_FIB,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005779 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5780 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005781 },
5782 "transpose": {
5783 "op": Op.TRANSPOSE,
5784 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01005785 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005786 "build_fcn": (
5787 build_transpose,
5788 TosaTensorGen.tgBasic,
5789 TosaArgGen.agTranspose,
5790 ),
5791 "types": TYPE_FIB,
Matthew Haddone807aae2021-10-11 18:12:58 +01005792 "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongRank,
5793 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005794 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005795 # Data nodes
5796 "const": {
5797 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07005798 "operands": (0, 1),
5799 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08005800 "types": TYPE_FIB,
5801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08005802 "identity": {
5803 "op": Op.IDENTITY,
5804 "operands": (1, 0),
5805 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
5806 "types": TYPE_FIB,
5807 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005808 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08005809 "gather": {
5810 "op": Op.GATHER,
5811 # Only specify 'values' tensor here. 'indices' is generated in op building stage
5812 "operands": (1, 0),
5813 "rank": (3, 3),
5814 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
5815 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005816 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5817 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005818 },
5819 "scatter": {
5820 "op": Op.SCATTER,
5821 # Only specify 'values_in' tensor here.
5822 #'indices' and 'input' are generated in op building stage
5823 "operands": (2, 0),
5824 "rank": (3, 3),
5825 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
5826 "types": TYPE_INT_FP,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005827 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5828 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005829 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005830 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08005831 "resize": {
5832 "op": Op.RESIZE,
5833 "operands": (1, 0),
5834 "rank": (4, 4),
5835 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
5836 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01005837 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
5838 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
5839 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01005840 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005841 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
5842 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005843 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005844 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08005845 "cast": {
5846 "op": Op.CAST,
5847 "operands": (1, 0),
5848 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
5849 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005850 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
5851 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005852 },
5853 "rescale": {
5854 "op": Op.RESCALE,
5855 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01005856 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08005857 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005858 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01005859 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
5860 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
5861 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005862 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005863 # Custom
5864 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08005865 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07005866 # Two varients of cond_if, one that generates one of two constant tensors (no
5867 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
5868 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005869 "cond_if_const": {
5870 "op": Op.COND_IF,
5871 "operands": (0, 2),
5872 "build_fcn": (
5873 build_cond_if_const,
5874 TosaTensorGen.tgBasic,
5875 TosaArgGen.agCondIf,
5876 ),
5877 "types": [DType.BOOL],
Matthew Haddon630c17c2021-10-14 15:05:41 +01005878 "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005879 },
5880 "cond_if_binary": {
5881 "op": Op.COND_IF,
5882 "operands": (2, 0),
5883 "build_fcn": (
5884 build_cond_if_binary,
5885 TosaTensorGen.tgBasic,
5886 TosaArgGen.agCondIf,
5887 ),
Les Bell6040b4d2021-10-11 12:50:31 +01005888 "types": TYPE_INT_FP,
Matthew Haddon630c17c2021-10-14 15:05:41 +01005889 "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch,
5890 TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005891 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005892 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08005893 "while_loop": {
5894 "op": Op.WHILE_LOOP,
5895 "operands": (0, 1),
5896 "build_fcn": (
5897 build_while_loop,
5898 TosaTensorGen.tgBasic,
5899 TosaArgGen.agWhileLoop,
5900 ),
5901 "types": [DType.INT32],
Matthew Haddon630c17c2021-10-14 15:05:41 +01005902 "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch,
5903 TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch,
5904 TosaErrorValidator.evCondGraphOutputNotMatchingBool)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005905 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005906 }
5907
Kevin Cheng550ccc52021-03-03 11:21:43 -08005908
class OutputShaper:
    """Compute the output shape and dtype of generated operators.

    Every method is a ``@staticmethod`` that works out the shape and dtype of
    an operator's output tensor and registers it with
    ``ser.addOutput(shape, dtype)``, returning whatever that call returns.

    Most methods also take ``rng`` and an optional ``error_name``. When
    ``error_name`` names an ``ErrorIf`` case, the shape and/or dtype is
    deliberately made wrong (or the normal computation is skipped) so the
    generated test exercises that ERROR_IF condition.
    """

    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _wrongOutputDType(rng, correct_dtype):
        """Randomly pick a dtype from a fixed candidate set, excluding ``correct_dtype``.

        The candidate list is built exactly as the individual methods used to
        build it inline. The same ``rng`` state therefore yields the same
        choice, which keeps generated tests reproducible.
        """
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        wrong_dtypes = list(set(all_dtypes) - set([correct_dtype]))
        return rng.choice(wrong_dtypes)

    @staticmethod
    def _accumulatorDType(ifm_dtype):
        """Map a convolution input dtype to its output (accumulator) dtype.

        INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT; anything else raises.
        """
        if ifm_dtype == DType.INT8:
            return DType.INT32
        if ifm_dtype == DType.INT16:
            return DType.INT48
        if ifm_dtype == DType.FLOAT:
            return DType.FLOAT
        raise Exception("Unsupported input dtype: {}".format(ifm_dtype))

    # ------------------------------------------------------------------
    # Elementwise operators
    # ------------------------------------------------------------------

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, rng, a, b, error_name=None):
        """Output of an elementwise binary op that broadcasts ``a`` against ``b``.

        Where ``a`` has size 1 the dimension is taken from ``b``. Broadcasting
        is only applied when no error is being generated; in any error case
        the output just copies ``a``'s shape.
        """
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i, a_dim in enumerate(a.shape):
            if a_dim == 1 and error_name is None:
                shape.append(b.shape[i])
            else:
                shape.append(a_dim)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        """Output of a binary op whose inputs must have identical shape and dtype."""
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for a_dim, b_dim in zip(a.shape, b.shape):
            assert a_dim == b_dim
            shape.append(a_dim)

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, rng, a, error_name=None):
        """Output of an elementwise unary op: same shape and dtype as ``a``."""
        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(a.shape, outputDType)

    @staticmethod
    def selectOp(ser, rng, cond, a, b, error_name=None):
        """Output of SELECT, broadcasting ``cond``, ``a`` and ``b`` together.

        The output rank follows ``cond``. A size-1 ``cond`` dimension takes
        the largest size among the three inputs, but only when no error is
        being generated.
        """
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i, cond_dim in enumerate(cond.shape):
            if cond_dim == 1 and error_name is None:
                shape.append(max(cond_dim, a.shape[i], b.shape[i]))
            else:
                shape.append(cond_dim)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def binaryComparisonOp(ser, rng, a, b, error_name=None):
        """Output of a broadcasting comparison op; the dtype is BOOL.

        Under ``ErrorIf.RankMismatch`` the ranks of ``a`` and ``b`` may
        differ. Dimensions that ``b`` does not have fall back to ``a``'s size
        instead of indexing past the end of ``b.shape``.
        """
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast (guarding against b having a lower rank than a)
        shape = []
        for i, a_dim in enumerate(a.shape):
            if a_dim == 1 and i < len(b.shape):
                shape.append(b.shape[i])
            else:
                shape.append(a_dim)

        if error_name == ErrorIf.WrongOutputType:
            # BOOL is the correct dtype and is not in this list; the choice
            # is made from the list in this order to keep RNG draws stable.
            wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = DType.BOOL

        return ser.addOutput(shape, outputDType)

    # ------------------------------------------------------------------
    # Reductions
    # ------------------------------------------------------------------

    @staticmethod
    def reduceOp(ser, rng, a, axis, error_name=None):
        """Output of a reduction: ``a``'s shape with ``shape[axis]`` set to 1.

        For the invalid-axis errors the shape is left unchanged.
        ``ShapeOfAxisNotOne`` forces the reduced dimension to a value in [2, 10).
        """
        shape = a.shape.copy()
        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
            shape[axis] = 1
        if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
            shape[axis] = rng.integers(2, 10)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def argmaxOp(ser, rng, a, axis, error_name=None):
        """Output of ARGMAX: ``a``'s shape with ``axis`` removed; the dtype is INT32."""
        shape = a.shape.copy()

        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
            del shape[axis]

        if error_name == ErrorIf.ArgmaxOutputRankMismatch:
            # Either drop a leading dim (if one can be spared) or add a trailing one
            remove = rng.choice([True, False])
            if remove and len(shape) > 1:
                del shape[0]
            else:
                shape.append(1)
        elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
            for i in range(len(shape)):
                shape[i] = shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, DType.INT32)
        else:
            outputDType = DType.INT32

        return ser.addOutput(shape, outputDType)

    # ------------------------------------------------------------------
    # Convolution / pooling / matrix operators
    # ------------------------------------------------------------------

    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):
        """Output of CONV2D.

        IFM: NHWC, Filter: OHWI, OFM: NHWC. ``padding`` is either
        [top, bottom, left, right] or the 2-element [H, W] form, which is
        expanded to 4 values.
        """
        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
        out_dtype = OutputShaper._accumulatorDType(ifm.dtype)

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def conv3dOp(ser, ifm, filter, strides, padding, dilations):
        """Output of CONV3D.

        IFM: NDHWC, Filter: ODHWI, OFM: NDHWC. ``padding`` holds 6 values,
        two per spatial dimension (D, H, W).
        """
        d = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        h = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        w = (
            ifm.shape[3]
            - filter.shape[3]
            - (filter.shape[3] - 1) * (dilations[2] - 1)
            + padding[4]
            + padding[5]
        ) // strides[2] + 1

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
        out_dtype = OutputShaper._accumulatorDType(ifm.dtype)

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        """Output of DEPTHWISE_CONV2D.

        IFM: NHWC, Filter: HWCM, OFM: NHW(C*M).
        """
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
        out_dtype = OutputShaper._accumulatorDType(ifm.dtype)

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
        """Output of a 2D pooling op on an NHWC input; the dtype is the input dtype.

        ``pad`` is [top, bottom, left, right]. A non-positive stride or a
        negative pad yields a 1x1 spatial output.
        """
        if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
            # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
            h = 1
            w = 1
        else:
            h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
            w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

        if error_name == ErrorIf.PoolingOutputShapeMismatch:
            choices = [1, 2, 3, 4, 5]
            h = h + rng.choice(choices)
            w = w + rng.choice(choices)

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, ifm.dtype)
        else:
            outputDType = ifm.dtype

        return ser.addOutput(ofm_shape, outputDType)

    @staticmethod
    def fullyConnectedOp(ser, rng, input, filter, error_name=None):
        """Output of FULLY_CONNECTED: input (N, IC) x filter (OC, IC) -> (N, OC).

        Raises:
            Exception: if the input dtype is not INT8/INT16/FLOAT and no
                matching error case applies.
        """
        output_shape = [input.shape[0], filter.shape[0]]

        if error_name == ErrorIf.WrongOutputType:
            if input.dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif input.dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif input.dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            else:
                # Previously fell through to an UnboundLocalError
                raise Exception("Unsupported input dtype: {}".format(input.dtype))
            out_dtype = rng.choice(a=incorrect_types)
        elif input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def matmulOp(ser, rng, a, b, error_name=None):
        """Output of MATMUL: a (N, H, C) x b (N, C, W) -> (N, H, W).

        Raises:
            Exception: if ``a``'s dtype is not INT8/INT16/FLOAT and no
                matching error case applies.
        """
        output_shape = [a.shape[0], a.shape[1], b.shape[2]]

        if error_name == ErrorIf.WrongOutputType:
            if a.dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif a.dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif a.dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            else:
                # Previously fell through to an UnboundLocalError
                raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
            out_dtype = rng.choice(a=incorrect_types)
        elif a.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif a.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif a.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

        return ser.addOutput(output_shape, out_dtype)

    # ------------------------------------------------------------------
    # Data layout operators
    # ------------------------------------------------------------------

    @staticmethod
    def concatOp(ser, rng, axis, *a, error_name=None):
        """Output of CONCAT of the tensors in ``a`` along ``axis``.

        The output dimension along ``axis`` is the sum of the inputs' sizes.
        For the rank-mismatch and invalid-axis errors, the first input's
        shape is used unchanged.
        """
        input1 = a[0]
        remaining_inputs = a[1:]

        # calculate the output shape, if possible, otherwise just use the first input shape
        output_shape = input1.shape.copy()
        if not (
            # unable to concat tensors of different ranks
            error_name == ErrorIf.ConcatInputRankMismatch
            # unable to concat tensors along an invalid axis
            or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
        ):
            for tensor in remaining_inputs:
                output_shape[axis] += tensor.shape[axis]

        if error_name == ErrorIf.ConcatShapeSumMismatch:
            output_shape[axis] += rng.integers(5, 10)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, input1.dtype)
        else:
            outputDType = input1.dtype

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def padOp(ser, rng, a, padding, error_name=None):
        """Output of PAD: each dim grows by ``padding[i][0] + padding[i][1]``."""
        output_shape = [
            pad[0] + pad[1] + dim for dim, pad in zip(a.shape, padding)
        ]

        # Fix negative output shape if error_if test causes it
        if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
            output_shape = [i if i >= 1 else 1 for i in output_shape]

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def reshapeOp(ser, rng, a, shape, error_name=None):
        """Output of RESHAPE to ``shape``; a -1 entry is inferred from ``a``'s size."""
        output_shape = shape.copy()

        totalElements = 1
        for dim in a.shape:
            totalElements *= dim

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for dim in output_shape:
            if dim != -1:
                totalOutputElements *= dim

        # And fill it in
        for i, dim in enumerate(output_shape):
            if dim == -1:
                output_shape[i] = totalElements // totalOutputElements

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            for i in range(len(output_shape)):
                output_shape[i] = output_shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def sliceOp(ser, rng, a, start, size, error_name=None):
        """Output of SLICE: the output shape is ``size``.

        ``start`` is accepted for interface compatibility and is not used
        here. ``SizeOutputShapeMismatch`` perturbs every dimension so that it
        no longer matches ``size``.
        """
        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        output_shape = size.copy()
        if error_name == ErrorIf.SizeOutputShapeMismatch:
            for index, dim in enumerate(output_shape):
                if dim <= 2:
                    # Only grow small dims so they stay positive
                    output_shape[index] = dim + rng.choice([1, 2])
                else:
                    output_shape[index] = dim + rng.choice([-2, -1, 1, 2])

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def tileOp(ser, rng, a, multiples, error_name=None):
        """Output of TILE: each dim of ``a`` multiplied by ``multiples[i]``."""
        assert len(multiples) == len(a.shape)

        output_shape = [dim * mult for dim, mult in zip(a.shape, multiples)]

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def transposeOp(ser, rng, a, perms, error_name=None):
        """Output of TRANSPOSE: ``output[i] = a.shape[perms[i]]``.

        For ``IndexOutsideBounds`` the permutation cannot be applied, so
        every output dimension is set to ``a.shape[0]``.
        """
        assert len(perms) == len(a.shape)

        if error_name == ErrorIf.IndexOutsideBounds:
            output_shape = [a.shape[0]] * len(a.shape)
        else:
            output_shape = [a.shape[p] for p in perms]

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, a.dtype)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

    # ------------------------------------------------------------------
    # Scatter / gather
    # ------------------------------------------------------------------

    @staticmethod
    def gatherOp(ser, rng, values, indices, error_name=None):
        """Output of GATHER: values (N, K, C) with indices (N, W) -> (N, W, C)."""
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, values.dtype)
        else:
            outputDType = values.dtype

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def scatterOp(ser, rng, values_in, indices, input, error_name=None):
        """Output of SCATTER: same shape as ``values_in``."""
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        # Copy so the new output does not share the input tensor's shape list
        output_shape = list(values_in.shape)

        if error_name == ErrorIf.WrongOutputType:
            outputDType = OutputShaper._wrongOutputDType(rng, values_in.dtype)
        else:
            outputDType = values_in.dtype

        return ser.addOutput(output_shape, outputDType)

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    @staticmethod
    def tableOp(ser, rng, input, error_name=None):
        """Output of TABLE: same shape as the input.

        The dtype is INT32 for INT16 input and INT8 otherwise.
        """
        if error_name != ErrorIf.WrongInputType:
            assert input.dtype == DType.INT16 or input.dtype == DType.INT8
        output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
        if error_name == ErrorIf.WrongOutputType:
            # Choose from the list in its declared order to keep RNG draws stable
            wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes.remove(output_dtype)
            output_dtype = rng.choice(wrong_dtypes)
        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name=None,
    ):
        """Output of RESIZE: [N, output_dims[0], output_dims[1], C] with ``output_dtype``.

        The mode/stride/offset/shift/input_dtype arguments are accepted for
        interface compatibility and are not used here.
        """
        if error_name == ErrorIf.WrongRank:
            # Only index input.shape[0]; input may not be rank 4 in this case
            output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
        elif error_name == ErrorIf.BatchMismatch:
            output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
        elif error_name == ErrorIf.ChannelMismatch:
            output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
        else:
            output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        return serializer.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
        """Output of a type conversion: ``val``'s shape with dtype ``out_dtype``.

        ``rng`` and ``error_name`` are accepted for interface consistency and
        are not used here.
        """
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
        """Output of TRANSPOSE_CONV2D: the given ``output_shape`` with the accumulator dtype."""
        out_dtype = OutputShaper._accumulatorDType(ifm.dtype)

        return ser.addOutput(output_shape, out_dtype)