blob: 6780aa7384453a431179addb0192ea1e06e32406 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
# Add ../thirdparty/serialization_lib/python (relative to this script's real
# location) to sys.path so that the tosa_serializer and tosa modules imported
# immediately below can be found. This must execute before those imports.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# Convenience handles for the flatc-generated types. flatc emits these as plain
# classes of constants rather than Python enums, so one instance of each is
# created here and its attributes (e.g. DType.INT8) are used as the values.
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """Random generators for quantization info (zero points).

    An operator definition selects one of these with its 'qgen' entry.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point for a tensor of type ``dtype``.

        INT8 draws from testGen.randInt(-128, 128) and UINT8 from
        testGen.randInt(0, 256). Any other type gets 0, unless ``error_name``
        is one of the zero-point ERROR_IF tests, in which case a random
        non-zero value is returned.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        if error_name in (ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero):
            zp = testGen.randInt(-128, 128)
            # A zero value would not trigger the error, so bump it to 1
            return zp if zp != 0 else 1
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo holding the input and output zero points.

        Only the side named by ``error_name`` (input or output) is generated
        in ERROR_IF mode; the other side is generated normally. The input
        zero point is always generated before the output one.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        input_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_error = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        input_zp = TosaQuantGen.getQinfo(testGen, dtype, input_error)
        output_zp = TosaQuantGen.getQinfo(testGen, dtype, output_error)
        qinfo.UnaryQuantInfo(input_zp, output_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo holding the input and weight zero points.

        ``dtype_or_dtypeList`` is either one dtype shared by input, weights and
        accumulator, or a list of [input, weights, accumulator] dtypes.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        dtypeList = (
            dtype_or_dtypeList
            if isinstance(dtype_or_dtypeList, list)
            else [dtype_or_dtypeList] * 3
        )
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo with two zero points of type ``dtype``."""
        qinfo = ts.TosaSerializerQuantInfo()
        first_zp = TosaQuantGen.getQinfo(testGen, dtype)
        second_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.MatMulQuantInfo(first_zp, second_zp)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo with a single zero point of type ``dtype``."""
        qinfo = ts.TosaSerializerQuantInfo()
        pad_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.PadQuantInfo(pad_zp)
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32. The multiplier uses
        31 fractional bits when ``scale32`` is true and 15 otherwise. The
        returned shift is kept within [2, 62]. The sign of ``scaleFp`` is
        discarded.
        """
        scaleBits = 31 if scale32 else 15
        one = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        mantissa = abs(mantissa)

        multiplier = round(mantissa * one)
        assert multiplier <= one

        # Rounding can push the mantissa up to exactly 1.0; renormalise
        if multiplier == one:
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Trade bits between multiplier and shift to keep shift in range
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= one
        assert 2 <= shift <= 62
        return multiplier, shift
158
159
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated
    separately for each test.

    Most generators take (testGen, op, rank, error_name) and return a list of
    shapes, one per operand.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one shape per operand, all equal to a random shape of ``rank``.

        For the RankMismatch ERROR_IF test, a new shape with a different rank
        is generated after every operand except operand 1. Operand 0 therefore
        keeps ``rank`` while the later operands use a different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical NHWC shapes for every operand.

        The rank must be 4 unless this is a WrongRank ERROR_IF test. The batch
        dimension is limited by testGen.args.max_batch_size when it is set.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
        # Constrict dimension size for large ranks
        if rank > 4:
            shape[4] = 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for SCATTER.

        values_in is [N, K, C]; input is [N, W, C] with a random W.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        # NOTE(review): randInt(0, 16) can yield W == 0 (a zero-sized tensor);
        # confirm whether that is intended.
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return one shape per operand, with one operand broadcast on a random axis.

        For RankMismatch ERROR_IF tests broadcasting is disabled and operands
        other than operand 1 get a shape of a different rank.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (O)] shapes for CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (O)] shapes for CONV3D."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (O)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C. The modulus is clamped to at least 1 so a
        # small tensor_shape_range (upper bound < 4) cannot divide by zero.
        max_filter_m = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeShape(1)[0] % max_filter_m) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter (OC x IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] for MATMUL, where b is [N, a_shape[2], b_oc]."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return identical shapes for the operands plus a random number of extra inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat: testGen.randInt(0, 4) extra inputs
        # on top of the operator's declared operands.
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split concat shapes along ``axis`` so several const inputs stay small.

        The first returned shape is the original, unsplit shape. Each further
        shape takes half (integer division) of the length still remaining on
        ``axis``, and the final shape takes whatever is left, so shapes 1..n-1
        together cover the full axis length. ``shapeList`` is returned
        unchanged when it has exactly two entries or when the axis is shorter
        than ``len(shapeList)``. The caller's shapes are never modified.
        """
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Work on a copy so the caller's shapeList is not altered
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Split the remaining length in half (exact integer arithmetic)
            split_shape_val = shape[axis] // 2
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
463
464
Eric Kunzee5e26762020-10-13 16:11:07 -0700465class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800466 """Argument generators create exhaustive or random lists of attributes for operators that take
467 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
468 tuples where the descriptive_name is appended to the test name and the arglist is expanded
469 as arguments to the operator build function."""
470
    def __init__(self):
        # Nothing to initialise: the argument generators are static methods.
        pass
473
474 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100475 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800476 """A trivial argument generator for operators that don't take any
477 non-tensor arguments"""
478 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700479
480 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100481 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800482 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700483 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700484 shape = shapeList[0]
485
Matthew Haddond6ce7252021-09-29 15:35:44 +0100486 if error_name == ErrorIf.AxisSmallerZero:
487 small_axis = testGen.rng.integers(-5, 0)
488 axes.append(("axis{}".format(small_axis), [small_axis]))
489 elif error_name == ErrorIf.AxisLargerRank:
490 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
491 axes.append(("axis{}".format(large_axis), [large_axis]))
492 else:
493 for a in range(0, len(shape)):
494 axes.append(("axis{}".format(a), [a]))
495
Eric Kunzee5e26762020-10-13 16:11:07 -0700496 return axes
497
498 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100499 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700500 arg_list = []
501
502 ifm_shape = shapeList[0]
503 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100504 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
505 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700506
Les Bell7aa69f42021-09-20 10:44:07 +0100507 # Check the rank
508 rank = 5 if opName.startswith("conv3d") else 4
509 assert len(ifm_shape) == rank
510 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700511
Les Bell7aa69f42021-09-20 10:44:07 +0100512 # kernel rank omits batch and channels
513 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700514
Les Bell7aa69f42021-09-20 10:44:07 +0100515 # Generate comprehensive argument lists
516 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
517 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
518 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
519 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
520 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
521 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700522
Les Bell7aa69f42021-09-20 10:44:07 +0100523 # add some oversize argument values
524 if max(ifm_shape) < 64:
525 bigPadding = 9
526 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
527 bigStride = 8
528 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
529 bigDilation = 7
530 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100531
532 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100533 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
534 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
535 if sparsity < 13:
536 sparsity = 1
537 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
538 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100539 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100540 for s in sorted(list(strides)):
541 for p in sorted(list(paddings)):
542 for d in sorted(list(dilations)):
543 if (n % sparsity == 0
544 # padding must not exceed the kernel size ?
545 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
546 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
547 # the padded shape must exceed the kernel size
548 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
549 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
550 # the padded shape must exceed the dilation
551 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
552 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
553 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100554 arg_list.append(
555 (
556 "st{}_pad{}_dilat{}".format(
557 "".join([str(x) for x in s]),
558 "".join([str(x) for x in p]),
559 "".join([str(x) for x in d]),
560 ),
561 [s, p, d],
562 )
563 )
564 n += 1
565
Kevin Cheng1533b852021-09-01 12:51:58 -0700566 return arg_list
567
568 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100569 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700570 arg_list = []
571
572 ifm_shape = shapeList[0]
573 filter_shape = shapeList[1]
574
575 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800576 assert len(ifm_shape) == 4
577 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700578
Les Bell7aa69f42021-09-20 10:44:07 +0100579 # Generate comprehensive argument lists
580 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
581 paddings = {x for x in itertools.product(*([p_vals] * 2))}
582 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
583 strides = {x for x in itertools.product(*([s_vals] * 2))}
584 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
585 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700586
Les Bell7aa69f42021-09-20 10:44:07 +0100587 # add some oversize argument values
588 if max(ifm_shape) < 64:
589 bigPadding = 9
590 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
591 bigStride = 8
592 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
593 bigDilation = 7
594 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700595
Les Bell7aa69f42021-09-20 10:44:07 +0100596 # There are too many parameter combinations, so generate them sparsely
597 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
598 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
599 if sparsity < 13:
600 sparsity = 1
601 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
602 sparsity += 1
603 n = 0
604 for s in sorted(list(strides)):
605 for p in sorted(list(paddings)):
606 for d in sorted(list(dilations)):
607 if n % sparsity == 0:
608 # Determine the output shape
609 oh = (
610 ifm_shape[1]
611 - filter_shape[1]
612 - (filter_shape[1] - 1) * (d[0] - 1)
613 + 2 * p[0]
614 ) // s[0] + 1
615 ow = (
616 ifm_shape[2]
617 - filter_shape[2]
618 - (filter_shape[2] - 1) * (d[1] - 1)
619 + 2 * p[1]
620 ) // s[1] + 1
621 os = [ifm_shape[0], oh, ow, filter_shape[0]]
622 arg_list.append(
623 (
624 "st{}_pad{}_dilat{}_os{}".format(
625 "".join([str(x) for x in s]),
626 "".join([str(x) for x in p]),
627 "".join([str(x) for x in d]),
628 "x".join([str(x) for x in os]),
629 ),
630 [s, p, d, os],
631 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800632 )
Les Bell7aa69f42021-09-20 10:44:07 +0100633 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700634
635 return arg_list
636
637 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100638 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 arg_list = []
640 rank = len(shapeList[0])
641
Les Bell7ffccce2021-07-28 15:37:02 +0100642 # Exhaustively test combinations of padding on each side of each dimension
643 # - the range of padding values is defined by pad_min and pad_max
644 # - for padding >9, the name format needs to be more distinctive
645 pad_min, pad_max = 0, 1
646 pad_values = [x for x in range(pad_min, pad_max + 1)]
647 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
648 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700649
Les Bell7ffccce2021-07-28 15:37:02 +0100650 for paddings in shape_pad_values:
651 name = "pad"
652 for r in range(rank):
653 before, after = paddings[r]
654 name = f"{name}{before}{after}"
655 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700656
657 return arg_list
658
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, kernel) argument sets for 2D pooling ops.

        Stride and kernel are 2-tuples. Padding is a 4-tuple: p[0]/p[1] pad
        dimension 1 of the input and p[2]/p[3] pad dimension 2, as the checks
        below show. Combinations come from the configured ranges plus some
        oversize values and are sampled sparsely.

        For the pooling ERROR_IF tests listed in the loop, every combination is
        passed through TosaErrorIfArgGen.eiPoolingErrorIf, and the values it
        returns are used. Otherwise, only combinations that pass the validity
        checks below are kept.

        Returns a list of ("st<s>_kern<k>_pad<p>", [stride, pad, kernel]) tuples.
        """
        arg_list = []

        shape = shapeList[0]
        # WrongRank tests deliberately supply a non-rank-4 input
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        # Kernel sizes start at 2
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    # Calculate output height to test for error_if conditions
                    oh = (shape[1] + p[0] + p[1] + s[0] - k[0]) // s[0]
                    ow = (shape[2] + p[2] + p[3] + s[1] - k[1]) // s[1]
                    # y/x equal (oh - 1) * stride + kernel - total padding on each
                    # dimension; the normal path requires them to be smaller than
                    # the input size.
                    # NOTE(review): intent appears to be rejecting windows that
                    # reach past the input; confirm.
                    y = (oh * s[0]) - p[0] - p[1] - s[0] + k[0]
                    x = (ow * s[1]) - p[2] - p[3] - s[1] + k[1]

                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        # Called for every combination, before the sparsity check
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                        and y < shape[1] and x < shape[2]
                    ):
                        # The comprehension `x` below has its own scope and does
                        # not clobber the outer `x` computed above.
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
730
@staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate the CAST output data types available for ``inDtype``.

    Returns a list of ``(test_name_suffix, [output_dtype])`` pairs, one per
    legal output type, in a fixed order.

    Raises:
        Exception: if ``inDtype`` has no entry in the cast table.
    """
    # Legal cast targets for each input data type; list order is the order
    # in which the tests are generated.
    cast_targets = {
        DType.INT8: [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT],
        DType.INT16: [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT],
        DType.INT32: [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT],
        DType.BOOL: [DType.INT8, DType.INT16, DType.INT32],
        DType.FLOAT: [DType.INT8, DType.INT16, DType.INT32],
    }
    if inDtype not in cast_targets:
        raise Exception("Unexpected input dtype: {}".format(inDtype))

    return [
        ("out{}".format(DTypeNames[target]), [target])
        for target in cast_targets[inDtype]
    ]
753
@staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate RESCALE output types and scale32/double_round/per_channel flags.

    Each entry is ``(test_name_suffix, [out_dtype, scale32, double_round,
    per_channel])``.  Combinations that would trip an ERROR_IF check are
    skipped, except where ``error_name`` asks for exactly that error.
    """
    wants_wrong_output = error_name == ErrorIf.WrongOutputType

    def _dtype_allowed(out_dtype):
        # Zero-point errors are only generated for the wider output types
        if error_name == ErrorIf.OutputZeroPointNotZero and out_dtype in [DType.UINT8, DType.INT8]:
            return False
        if wants_wrong_output:
            # Keep only genuinely illegal output types for this error test
            return TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, out_dtype)
        # The only output dtype for UINT8 is INT8, skip all other combinations
        if inDtype == DType.UINT8 and out_dtype != DType.INT8:
            return False
        # The only input dtype for UINT8 is INT8, skip all other combinations
        if inDtype != DType.INT8 and out_dtype == DType.UINT8:
            return False
        return True

    def _flags_allowed(scale32, double_round):
        if error_name == ErrorIf.ScaleTrue and not scale32:
            return False
        if error_name == ErrorIf.ScaleNotTrue and (scale32 or not double_round):
            return False
        # Illegal condition. Must be scale32=False
        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
            return False
        # Illegal condition. ERROR_IF(!scale32 && double_round)
        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
            return False
        return True

    arg_list = []
    for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
        if not _dtype_allowed(dtype):
            continue
        # product() yields scale32 outermost and per_channel innermost,
        # matching a three-deep nested loop.
        for scale32, double_round, per_channel in itertools.product([False, True], repeat=3):
            if not _flags_allowed(scale32, double_round):
                continue
            arg_list.append(
                (
                    "out{}_sc{}_dr{}_pc{}".format(
                        DTypeNames[dtype],
                        int(scale32),
                        int(double_round),
                        int(per_channel),
                    ),
                    [dtype, scale32, double_round, per_channel],
                )
            )

    return arg_list
801
@staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Generate the ``shift`` argument for MUL tests.

    INT32 multiplies get ``testGen.args.num_rand_permutations`` entries, each
    with a shift drawn by ``testGen.randInt(0, 32)``.  Every other data type
    gets a single entry with shift 0.

    Returns:
        A list of ``(test_name_suffix, [shift])`` pairs.
    """
    # DType values come from a flatc-generated class (not an Enum), so they
    # must be compared by value; ``is`` only worked by relying on CPython's
    # small-integer caching.
    if dtype != DType.INT32:
        return [("perm0_shift0", [0])]

    arg_list = []
    for p in range(testGen.args.num_rand_permutations):
        shift = testGen.randInt(0, 32)
        arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
    return arg_list
816
817 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100818 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800819 arg_list = []
820
Kevin Cheng550ccc52021-03-03 11:21:43 -0800821 arg_list.append(("roundTrue", [True]))
822 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800823
824 return arg_list
825
Eric Kunzee5e26762020-10-13 16:11:07 -0700826 # Helper function for reshape. Gets some factors of a larger number.
827 @staticmethod
828 def getFactors(val, start=1):
829 factors = []
830
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100831 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700832 if (val % i) == 0:
833 factors.append(i)
834
835 return factors
836
@staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random, non-duplicate target shapes for RESHAPE.

    For each of ``num_rand_permutations`` attempts, a rank is drawn with
    ``testGen.randInt(1, 7)``.  A shape with the same element count as
    ``shapeList[0]`` is then built from randomly chosen factors, and one
    dimension is sometimes replaced by -1.  An attempt is dropped when the
    rank exceeds the number of available factors, or when it still
    duplicates an earlier shape after 100 tries.

    Returns:
        A list of ``(test_name_suffix, [new_shape])`` pairs.
    """
    arg_list = []

    totalElements = 1
    for dim in shapeList[0]:
        totalElements *= dim

    # This code is NOT fast. Fortunately, the numbers are fairly small.
    factors = TosaArgGen.getFactors(totalElements)

    for p in range(testGen.args.num_rand_permutations):
        newRank = testGen.randInt(1, 7)
        if newRank > len(factors):
            continue

        # Regenerate until the shape is new, giving up after 100 tries
        attempts = 0
        while True:
            # Pick rank-1 random factors; the last dim takes what is left
            newShape = []
            remaining = totalElements
            candidates = testGen.rng.permutation(factors)
            for _ in range(newRank - 1):
                newShape.append(candidates[0])
                remaining = remaining // candidates[0]
                candidates = testGen.rng.permutation(
                    TosaArgGen.getFactors(remaining)
                )
            newShape.append(remaining)

            # Toss in a -1 sometimes
            minusOne = testGen.randInt(0, newRank * 4)
            if minusOne < newRank:
                newShape[minusOne] = -1

            duplicate = any(existing[0] == newShape for _, existing in arg_list)
            attempts += 1
            if not duplicate or attempts >= 100:
                break

        if not duplicate:
            arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

    return arg_list
892
Eric Kunzee5e26762020-10-13 16:11:07 -0700893 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100894 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700895 arg_list = []
896
897 ifm_shape = shapeList[0]
898
Jeremy Johnsona6185572021-06-21 15:55:35 +0100899 # Get all permutations
900 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700901
Jeremy Johnsona6185572021-06-21 15:55:35 +0100902 # Limit to possible permutations from shape dimension or argument setting
903 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700904
Jeremy Johnsona6185572021-06-21 15:55:35 +0100905 # Get random permutation generator that uses all permutations
906 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700907
Jeremy Johnsona6185572021-06-21 15:55:35 +0100908 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700909 arg_list = [
910 ("perm{}".format(p), [random_permutations[p].tolist()])
911 for p in range(limit)
912 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700913 return arg_list
914
915 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100916 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 arg_list = []
918
919 ifm_shape = shapeList[0]
920 rank = len(ifm_shape)
921
922 for p in range(testGen.args.num_rand_permutations):
923 begin = []
924 size = []
925
Kevin Cheng550ccc52021-03-03 11:21:43 -0800926 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700927
928 for i in range(rank):
929 if ifm_shape[i] > 1:
930 begin.append(testGen.randInt(0, ifm_shape[i]))
931 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
932
933 # Invalid slice size?
934 if size[i] == 0:
935 valid = False
936 else:
937 begin.append(0)
938 size.append(1)
939
940 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800941 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700942 return arg_list
943
944 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100945 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700946 arg_list = []
947
948 ifm_shape = shapeList[0]
949 rank = len(ifm_shape)
950
951 for p in range(testGen.args.num_rand_permutations):
952
953 # Pick a few random, but small multiple values
954 # because otherwise this has a tendency to generate
955 # enormous tensors
956 multiples = []
957 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +0100958 if ifm_shape[i] > 1000:
959 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
960 multiples.append(1)
961 elif max(ifm_shape) > 1000:
962 multiples.append(2)
963 else:
964 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -0800965 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700966
967 return arg_list
968
@staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Generate RESIZE arguments for every legal (mode, output dtype) pair.

    For each pair, ``num_rand_permutations`` random output sizes are drawn
    and the matching stride/offset are derived from them:
    - FLOAT outputs use floating-point ``stride_fp``/``offset_fp``;
    - integer outputs use fixed-point ``stride``/``offset`` scaled by
      ``1 << shift``.

    When ``error_name`` is set, the values are passed through
    ``TosaErrorIfArgGen.eiResizeErrorIf`` to provoke that error.

    Returns:
        A list of ``(test_name_suffix, [mode, stride, offset, shift,
        stride_fp, offset_fp, output_dims, dtype, output_dtype])`` entries.
    """
    ifm_shape = shapeList[0]
    arg_list = []

    # Legal output type for each {mode, integer input type} configuration
    int_outputs = {
        (ResizeMode.NEAREST, DType.INT8): [DType.INT8],
        (ResizeMode.NEAREST, DType.INT16): [DType.INT16],
        (ResizeMode.BILINEAR, DType.INT8): [DType.INT32],
        (ResizeMode.BILINEAR, DType.INT16): [DType.INT48],
    }

    def _toFixedPoint(fp_stride_y, fp_stride_x, fp_offset_y, fp_offset_x):
        # Start at shift 11 and lower it until every stride is below
        # 16 << shift and every offset lies strictly inside +/-(16 << shift).
        shift = 11
        while True:
            unit = float(1 << shift)
            stride = [int(round(fp_stride_y * unit)), int(round(fp_stride_x * unit))]
            offset = [int(round(fp_offset_y * unit)), int(round(fp_offset_x * unit))]
            limit = 16 << shift
            if max(stride) < limit and all(-limit < o < limit for o in offset):
                return shift, stride, offset
            shift -= 1

    for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
        # Exclude illegal {mode, type} configurations. Pick legal output types
        if (mode, dtype) in int_outputs:
            outputDTypeList = int_outputs[(mode, dtype)]
        elif dtype == DType.FLOAT:
            outputDTypeList = [DType.FLOAT]
        elif error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            continue

        mode_char = "N" if mode == ResizeMode.NEAREST else "B"

        for outputDType in outputDTypeList:
            is_float = outputDType == DType.FLOAT
            for perm in range(testGen.args.num_rand_permutations):
                # Randomly generate legal output dimensions and shift
                # and then compute the stride and offset based on them
                # A output_dim of 1 will cause offset to exceed allowed range
                # so minimum value 2 produced below
                output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
                    output_dims[0] += 1
                while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
                    output_dims[1] += 1

                in_center_h = (ifm_shape[1] - 1) / 2.0
                in_center_w = (ifm_shape[2] - 1) / 2.0
                out_center_h = (output_dims[0] - 1) / 2.0
                out_center_w = (output_dims[1] - 1) / 2.0

                fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                fp_offset_y = in_center_h - fp_stride_y * out_center_h
                fp_offset_x = in_center_w - fp_stride_x * out_center_w

                if is_float:
                    shift = 0
                    stride = [0, 0]
                    offset = [0, 0]
                    stride_fp = [fp_stride_y, fp_stride_x]
                    offset_fp = [fp_offset_y, fp_offset_x]
                else:
                    shift, stride, offset = _toFixedPoint(
                        fp_stride_y, fp_stride_x, fp_offset_y, fp_offset_x
                    )
                    stride_fp = [0.0, 0.0]
                    offset_fp = [0.0, 0.0]

                if error_name is None:
                    outputDTypeNew = outputDType
                else:
                    shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                        testGen,
                        error_name,
                        mode,
                        dtype,
                        shapeList,
                        outputDType,
                        shift,
                        stride,
                        stride_fp,
                        offset,
                        offset_fp
                    )

                if is_float:
                    test_name = "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                        mode_char,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride_fp[0],
                        stride_fp[1],
                        offset_fp[0],
                        offset_fp[1],
                    )
                else:
                    test_name = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                        mode_char,
                        shift,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride[0],
                        stride[1],
                        offset[0],
                        offset[1],
                    )

                arg_list.append(
                    (
                        test_name,
                        [
                            mode,
                            stride,
                            offset,
                            shift,
                            stride_fp,
                            offset_fp,
                            output_dims,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                )

    return arg_list
1139
Matthew Haddon1c00b712021-10-01 15:51:03 +01001140 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001141 # CondIf generates the condition values here.
1142 # Convert to tensors in the build function, along with the
1143 # then and else blocks
1144 arg_list = []
1145
1146 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001147 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001148
1149 return arg_list
1150
Matthew Haddon1c00b712021-10-01 15:51:03 +01001151 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001152 # While loop: 0 iterations, 1, more than 1
1153 arg_list = []
1154
1155 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001156 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001157
1158 return arg_list
1159
class TosaErrorIfArgGen:
    """Argument perturbation helpers used to build ERROR_IF (negative) tests.

    Each helper takes otherwise-legal operator arguments and, for the
    requested ``error_name``, returns modified arguments intended to trigger
    that specific ERROR_IF condition.
    """

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Perturb RESIZE arguments so that ``error_name`` is raised.

        Arguments not relevant to ``error_name`` are returned unchanged.  The
        StrideLargerDimension cases overwrite one element of ``stride`` /
        ``stride_fp`` in place.

        Returns:
            ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``

        Raises:
            Exception: for WrongOutputType with a (mode, dtype) pair that has
                no known legal RESIZE output type.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                # Largest stride/offset that stays inside +/-(16 << shift),
                # so that only the shift check fails (avoids other ERROR_IF checks)
                if shift <= 0:
                    limit = (16 >> -shift) - 1
                else:
                    limit = (16 << shift) - 1
                stride = [limit, limit]
                offset = [limit, limit]
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            else:
                # Previously this fell through to an unbound-variable error
                raise Exception(
                    "Unexpected RESIZE configuration for WrongOutputType: mode {} dtype {}".format(mode, dtype)
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Perturb pooling stride/pad/kernel so that ``error_name`` is raised.

        Returns:
            ``(stride, pad, kernel)`` with the offending value replaced, or
            ``(None, None, None)`` when no valid perturbation applies.  This
            includes StrideSmallerOne with a pad that already reaches the
            kernel size.
        """
        if (error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]),
                        testGen.rng.choice([-1, -2, -3]))
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            # Top/bottom pads are compared with kernel height, left/right with width
            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when ``output_dtype`` is not a legal RESCALE output for ``input_dtype``.

        Input types without an entry in this table are treated as having no
        illegal outputs, so the function returns False for them.
        """
        if input_dtype == DType.INT8:
            legal_outputs = [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]
        elif input_dtype in [DType.INT16, DType.INT32, DType.INT48]:
            legal_outputs = [DType.INT8, DType.INT16, DType.INT32]
        elif input_dtype == DType.UINT8:
            legal_outputs = [DType.INT8]
        else:
            return False
        return output_dtype not in legal_outputs

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Corrupt the operator's input/output name lists for ERROR_IF checks.

        - WrongInputList: randomly append a dummy input or drop the last input.
        - WrongOutputList: randomly append a dummy output or drop all outputs.

        New lists are returned and the caller's lists are left untouched.
        Previously the "append" case mutated them while the "drop" cases did
        not.

        Returns:
            ``(input_list, output_list)``
        """
        # Compare against the ErrorIf constants, as every other helper does
        if error_name == ErrorIf.WrongInputList:
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list = input_list + ['eiDummyInput']
            else:
                input_list = input_list[:-1]
        if error_name == ErrorIf.WrongOutputList:
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list = output_list + ['eiDummyOutput']
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimension(shape, error_name):
        """Shrink dimension 4 to 1 for WrongRank tests on rank > 4 shapes.

        This keeps the test sizes reasonably small.  ``shape`` is modified in
        place and also returned.
        """
        if error_name == ErrorIf.WrongRank:
            if len(shape) > 4:
                shape[4] = 1

        return shape
1288
Matthew Haddone86fd342021-09-07 16:12:21 +01001289class TosaErrorValidator:
1290
Matthew Haddon848efb42021-09-09 12:30:53 +01001291 @staticmethod
1292 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1293 # Check ERROR_IF statements
1294
1295 for val_fcn in validator_fcns:
1296 val_result = val_fcn(True, **kwargs)
1297
1298 validator_name = val_result['error_name']
1299 error_result = val_result['error_result']
1300 error_reason = val_result['error_reason']
1301
1302 if error_result:
1303 if error_name == validator_name:
1304 serializer.setExpectedReturnCode(2, error_reason)
1305 else:
1306 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1307 return None # Return None to delete test if wrong ERROR_IF is hit
1308 else:
1309 if error_name == validator_name:
1310 print(f"No ERROR_IF hit for {error_name}")
1311 return None
1312
@staticmethod
def evWrongInputType(check=False, **kwargs):
    """ERROR_IF validator: the input dtype is not in ``op['types']``.

    ``param_reqs['dtype']`` lists the dtypes the operator does not accept, for
    use when generating the negative test.  When ``check`` is true,
    ``kwargs['input_dtype']`` is tested against the supported types.
    """
    assert 'op' in kwargs
    supported_dtypes = kwargs['op']['types']

    # Find the unsupported input data types
    every_dtype = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
    unsupported_dtypes = list(set(every_dtype) - set(supported_dtypes))

    hit = False
    if check and kwargs['input_dtype'] not in supported_dtypes:
        hit = True

    return {
        "error_name": ErrorIf.WrongInputType,
        "error_result": hit,
        "error_reason": "Input data type not supported for this operator",
        "param_reqs": {"rank": None, "dtype": unsupported_dtypes, "shape": None},
    }
1340
@staticmethod
def evWrongOutputType(check=False, **kwargs):
    """ERROR_IF validator: the output dtype is illegal for this operator configuration.

    When ``check`` is true:
    - RESIZE checks the (mode, input dtype) -> output dtype table;
    - RESCALE uses ``TosaErrorIfArgGen.eiRescaleWrongOutputType``;
    - every other operator requires the output dtype to equal the input dtype.
    """
    error_name = ErrorIf.WrongOutputType
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Output data type not supported for this configuration of operator"

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        op = kwargs['op']

        if op['op'] == Op.RESIZE:
            mode = kwargs['mode']
            if (
                (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
            ):
                error_result = True
        elif op['op'] == Op.RESCALE:
            # Share the RESCALE legality table with the argument generator
            # instead of keeping a second copy that could drift out of sync.
            if TosaErrorIfArgGen.eiRescaleWrongOutputType(input_dtype, output_dtype):
                error_result = True
        else:
            if output_dtype != input_dtype:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs
    }
    return info_dict
1387
@staticmethod
def evWrongRank(check=False, **kwargs):
    """ERROR_IF validator: the input rank is outside the operator's supported range.

    ``op['rank']`` is an inclusive ``(min, max)`` range.  RESIZE, AVG_POOL2D
    and MAX_POOL2D additionally require exactly rank 4.
    ``param_reqs['rank']`` lists ranks to use for the negative test.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    rmin, rmax = op['rank']
    rank_range = range(rmin, rmax + 1)

    if op['op'] in [Op.RESIZE]:
        # Set minimum incorrect rank to 3 to avoid index error
        incorrect_ranks = [3, 5]
    else:
        # Ranks outside the supported range, dropping ranks below the minimum
        # to avoid index errors
        all_ranks = (1, 2, 3, 4, 5)
        incorrect_ranks = [r for r in set(all_ranks) - set(rank_range) if r > rmin]

    hit = False
    if check:
        in_rank = len(kwargs['input_shape'])
        if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and in_rank != 4:
            hit = True
        elif in_rank not in rank_range:
            hit = True

    return {
        "error_name": ErrorIf.WrongRank,
        "error_result": hit,
        "error_reason": "Rank not supported for this operator",
        "param_reqs": {"rank": incorrect_ranks, "dtype": None, "shape": None},
    }
1424
@staticmethod
def evWrongInputList(check=False, **kwargs):
    """ERROR_IF validator: ``len(input_list)`` differs from ``num_operands``."""
    report = {
        "error_name": ErrorIf.WrongInputList,
        "error_result": False,
        "error_reason": "Op input list does not match expected input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }

    if check:
        # 'op' is looked up (a missing key raises KeyError) but not used further
        op = kwargs['op']
        inputs = kwargs['input_list']
        expected_count = kwargs['num_operands']
        if len(inputs) != expected_count:
            report["error_result"] = True

    return report
1446
@staticmethod
def evWrongOutputList(check=False, **kwargs):
    """ERROR_IF validator: the operator does not have exactly one output.

    This assumes single-output operators and reports an error for any
    operator that legitimately returns more than one output.
    """
    report = {
        "error_name": ErrorIf.WrongOutputList,
        "error_result": False,
        "error_reason": "Op output list does not match expected output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }

    if check and len(kwargs['output_list']) != 1:
        report["error_result"] = True

    return report
Matthew Haddone86fd342021-09-07 16:12:21 +01001467
1468 @staticmethod
1469 def evMaxDimExceeded(check=False, **kwargs):
1470 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001471 param_reqs = {
1472 "rank": [4,4],
1473 "dtype": [DType.INT8],
1474 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1475 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001476 error_result = False
1477 error_reason = "At least one maximum dimension is larger than 16384"
1478
1479 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001480 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001481 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1482 if ((input_shape[1] > 16384) or
1483 (input_shape[2] > 16384) or
1484 (output_shape[0] > 16384) or
1485 (output_shape[1] > 16384)):
1486 error_result = True
1487
1488 info_dict = {
1489 "error_name": error_name,
1490 "error_result": error_result,
1491 "error_reason": error_reason,
1492 "param_reqs": param_reqs
1493 }
1494 return info_dict
1495
1496 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001497 def evBatchMismatch(check=False, **kwargs):
1498 error_name = ErrorIf.BatchMismatch
1499 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1500 error_result = False
1501 error_reason = "Input batch size not equal to output batch size"
1502
1503 assert 'op' in kwargs
1504 op = kwargs['op']
1505 rmin, rmax = op['rank']
1506 rank_range = range(rmin, rmax + 1)
1507
1508 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001509 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001510 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1511
1512 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1513 error_result = True
1514
1515 info_dict = {
1516 "error_name": error_name,
1517 "error_result": error_result,
1518 "error_reason": error_reason,
1519 "param_reqs": param_reqs
1520 }
1521 return info_dict
1522
1523 @staticmethod
1524 def evChannelMismatch(check=False, **kwargs):
1525 error_name = ErrorIf.ChannelMismatch
1526 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1527 error_result = False
1528 error_reason = "Input channel size not equal to output channel size"
1529
1530 assert 'op' in kwargs
1531 op = kwargs['op']
1532 rmin, rmax = op['rank']
1533 rank_range = range(rmin, rmax + 1)
1534
1535 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001536 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001537 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1538 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1539 error_result = True
1540
1541 info_dict = {
1542 "error_name": error_name,
1543 "error_result": error_result,
1544 "error_reason": error_reason,
1545 "param_reqs": param_reqs
1546 }
1547 return info_dict
1548
1549 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001550 def evStrideSmallerEqualZero(check=False, **kwargs):
1551 error_name = ErrorIf.StrideSmallerEqualZero
1552 param_reqs = {"rank": None, "dtype": None, "shape": None}
1553 error_result = False
1554 error_reason = "Stride value smaller than or equal zero"
1555
1556 if check:
1557 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001558 output_dtype = kwargs['output_dtype']
1559 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1560 stride = kwargs['stride'] # Work around wrong input/output type tests
1561 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001562 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001563 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1564 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001565 else:
1566 stride = kwargs['stride']
1567
1568 if min(stride) <= 0:
1569 error_result = True
1570
1571 info_dict = {
1572 "error_name": error_name,
1573 "error_result": error_result,
1574 "error_reason": error_reason,
1575 "param_reqs": param_reqs
1576 }
1577 return info_dict
1578
1579 @staticmethod
1580 def evStrideLargerEqualMax(check=False, **kwargs):
1581 error_name = ErrorIf.StrideLargerEqualMax
1582 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1583 error_result = False
1584 error_reason = "Stride value larger than or equal to maximum value"
1585
1586 if check:
1587 shift = kwargs['shift']
1588 input_dtype = kwargs['input_dtype']
1589 stride = kwargs['stride']
1590 if input_dtype in [DType.INT8, DType.INT16]:
1591 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1592 error_result = True
1593 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1594 error_result = True
1595
1596 info_dict = {
1597 "error_name": error_name,
1598 "error_result": error_result,
1599 "error_reason": error_reason,
1600 "param_reqs": param_reqs
1601 }
1602 return info_dict
1603
1604
1605 @staticmethod
1606 def evStrideLargerDimension(check=False, **kwargs):
1607 error_name = ErrorIf.StrideLargerDimension
1608 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1609 error_result = False
1610 error_reason = "Stride value larger than or equal to H/W dimension"
1611
1612 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001613 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001614 input_dtype = kwargs['input_dtype']
1615 stride = kwargs['stride_fp']
1616
1617 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1618 error_result = True
1619
1620 info_dict = {
1621 "error_name": error_name,
1622 "error_result": error_result,
1623 "error_reason": error_reason,
1624 "param_reqs": param_reqs
1625 }
1626 return info_dict
1627
1628
1629 @staticmethod
1630 def evOffsetSmallerEqualMin(check=False, **kwargs):
1631 error_name = ErrorIf.OffsetSmallerEqualMin
1632 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1633 error_result = False
1634 error_reason = "Offset value smaller than or equal to minimum value"
1635
1636 if check:
1637 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001638 output_dtype = kwargs['output_dtype']
1639 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001640 offset = kwargs['offset_fp']
1641 else:
1642 offset = kwargs['offset']
1643
1644 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1645 error_result = True
1646 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1647 error_result = True
1648
1649 info_dict = {
1650 "error_name": error_name,
1651 "error_result": error_result,
1652 "error_reason": error_reason,
1653 "param_reqs": param_reqs
1654 }
1655 return info_dict
1656
1657 @staticmethod
1658 def evOffsetLargerEqualMax(check=False, **kwargs):
1659 error_name = ErrorIf.OffsetLargerEqualMax
1660 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1661 error_result = False
1662 error_reason = "Offset value larger than or equal to maximum value"
1663
1664 if check:
1665 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001666 output_dtype = kwargs['output_dtype']
1667 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001668 offset = kwargs['offset_fp']
1669 else:
1670 offset = kwargs['offset']
1671
1672 if shift >= 0:
1673 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1674 error_result = True
1675
1676 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1677 error_result = True
1678 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1679 error_result = True
1680
1681 info_dict = {
1682 "error_name": error_name,
1683 "error_result": error_result,
1684 "error_reason": error_reason,
1685 "param_reqs": param_reqs
1686 }
1687 return info_dict
1688
1689 @staticmethod
1690 def evShiftNotZero(check=False, **kwargs):
1691 error_name = ErrorIf.ShiftNotZero
1692 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1693 error_result = False
1694 error_reason = "Shift value must be zero for float input"
1695
1696 if check:
1697 shift = kwargs['shift']
1698 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001699 output_dtype = kwargs['output_dtype']
1700 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001701 error_result = True
1702
1703 info_dict = {
1704 "error_name": error_name,
1705 "error_result": error_result,
1706 "error_reason": error_reason,
1707 "param_reqs": param_reqs
1708 }
1709 return info_dict
1710
1711
1712 @staticmethod
1713 def evShiftSmallerOne(check=False, **kwargs):
1714 error_name = ErrorIf.ShiftSmallerOne
1715 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1716 error_result = False
1717 error_reason = "Shift value smaller than one"
1718
1719 if check:
1720 shift = kwargs['shift']
1721 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001722 output_dtype = kwargs['output_dtype']
1723 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001724 error_result = True
1725
1726 info_dict = {
1727 "error_name": error_name,
1728 "error_result": error_result,
1729 "error_reason": error_reason,
1730 "param_reqs": param_reqs
1731 }
1732 return info_dict
1733
1734 @staticmethod
1735 def evShiftLargerEleven(check=False, **kwargs):
1736 error_name = ErrorIf.ShiftLargerEleven
1737 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1738 error_result = False
1739 error_reason = "Shift value larger than eleven"
1740
1741 if check:
1742 shift = kwargs['shift']
1743 if shift > 11:
1744 error_result = True
1745
1746 info_dict = {
1747 "error_name": error_name,
1748 "error_result": error_result,
1749 "error_reason": error_reason,
1750 "param_reqs": param_reqs
1751 }
1752 return info_dict
1753
1754
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001755 @staticmethod
1756 def evRankMismatch(check=False, **kwargs):
1757 error_name = ErrorIf.RankMismatch
1758 param_reqs = {"rank": None, "dtype": None, "shape": None}
1759 error_result = False
1760 error_reason = "Input Rank does not match output rank"
1761
1762 if check:
1763 input1_shape = kwargs['input1'].shape
1764 input2_shape = kwargs['input2'].shape
1765 output_shape = kwargs['result_tensor'].shape
1766 if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
1767 error_result = True
1768
1769 info_dict = {
1770 "error_name": error_name,
1771 "error_result": error_result,
1772 "error_reason": error_reason,
1773 "param_reqs": param_reqs
1774 }
1775 return info_dict
1776
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001777 @staticmethod
1778 def evInputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001779 op = kwargs['op']
1780 inputDtypes = op['types'].copy()
1781 if DType.INT8 in inputDtypes:
1782 inputDtypes.remove(DType.INT8)
1783 if DType.UINT8 in inputDtypes:
1784 inputDtypes.remove(DType.UINT8)
1785
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001786 error_name = ErrorIf.InputZeroPointNotZero
1787 param_reqs = {
1788 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001789 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001790 "shape": None
1791 }
1792 error_result = False
1793 error_reason = "Input DType not INT8 and zero point not 0"
1794
1795 if check:
1796 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01001797 if isinstance(kwargs['qinfo'], tuple):
1798 qinfo = kwargs['qinfo']
1799 input_zero_point = qinfo[0]
1800 else:
1801 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1802 qinfo = kwargs['qinfo'].ints
1803 input_zero_point = qinfo[0][1]
1804
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001805 if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
1806 error_result = True
1807
1808 info_dict = {
1809 "error_name": error_name,
1810 "error_result": error_result,
1811 "error_reason": error_reason,
1812 "param_reqs": param_reqs
1813 }
1814 return info_dict
1815
1816
1817 @staticmethod
1818 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001819 op = kwargs['op']
1820 inputDtypes = op['types'].copy()
1821 if DType.INT8 in inputDtypes:
1822 inputDtypes.remove(DType.INT8)
1823 if DType.UINT8 in inputDtypes:
1824 inputDtypes.remove(DType.UINT8)
1825
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001826 error_name = ErrorIf.OutputZeroPointNotZero
1827 param_reqs = {
1828 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001829 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001830 "shape": None
1831 }
1832 error_result = False
1833 error_reason = "Output DType not INT8 and zero point not 0"
1834
1835 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001836 input_dtype = kwargs['input_dtype']
Matthew Haddonc2025212021-10-08 21:21:05 +01001837 output_dtype = kwargs['output_dtype']
1838 if isinstance(kwargs['qinfo'], tuple):
1839 qinfo = kwargs['qinfo']
1840 output_zero_point = qinfo[1]
1841 else:
1842 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1843 qinfo = kwargs['qinfo'].ints
1844 output_zero_point = qinfo[1][1]
1845 if op['op'] == Op.AVG_POOL2D:
1846 if input_dtype != DType.INT8 and output_zero_point != 0:
1847 error_result = True
1848 elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001849 error_result = True
1850
1851 info_dict = {
1852 "error_name": error_name,
1853 "error_result": error_result,
1854 "error_reason": error_reason,
1855 "param_reqs": param_reqs
1856 }
1857 return info_dict
1858
Matthew Haddond6ce7252021-09-29 15:35:44 +01001859 @staticmethod
1860 def evAxisSmallerZero(check=False, **kwargs):
1861 error_name = ErrorIf.AxisSmallerZero
1862 param_reqs = {"rank": None, "dtype": None, "shape": None}
1863 error_result = False
1864 error_reason = "Axis smaller than zero"
1865
1866 if check:
1867 axis = kwargs['axis']
1868 if axis < 0:
1869 error_result = True
1870
1871 info_dict = {
1872 "error_name": error_name,
1873 "error_result": error_result,
1874 "error_reason": error_reason,
1875 "param_reqs": param_reqs
1876 }
1877 return info_dict
1878
1879
1880 @staticmethod
1881 def evAxisLargerRank(check=False, **kwargs):
1882 error_name = ErrorIf.AxisLargerRank
1883 param_reqs = {"rank": None, "dtype": None, "shape": None}
1884 error_result = False
1885 error_reason = "Axis larger than rank"
1886
1887 if check:
1888 axis = kwargs['axis']
1889 shape = kwargs['input_shape']
1890 if axis > len(shape):
1891 error_result = True
1892
1893 info_dict = {
1894 "error_name": error_name,
1895 "error_result": error_result,
1896 "error_reason": error_reason,
1897 "param_reqs": param_reqs
1898 }
1899 return info_dict
1900
1901
1902 @staticmethod
1903 def evShapeOfAxisNotOne(check=False, **kwargs):
1904 error_name = ErrorIf.ShapeOfAxisNotOne
1905 param_reqs = {"rank": None, "dtype": None, "shape": None}
1906 error_result = False
1907 error_reason = "shape[axis] is not equal to 1"
1908
1909 if check:
1910 axis = kwargs['axis']
1911 shape = kwargs['output_shape']
1912 if (0 <= axis < len(shape)) and shape[axis] != 1:
1913 error_result = True
1914
1915 info_dict = {
1916 "error_name": error_name,
1917 "error_result": error_result,
1918 "error_reason": error_reason,
1919 "param_reqs": param_reqs
1920 }
1921 return info_dict
1922
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001923
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001924 @staticmethod
1925 def evPadSmallerZero(check=False, **kwargs):
1926 error_name = ErrorIf.PadSmallerZero
1927 param_reqs = {"rank": None, "dtype": None, "shape": None}
1928 error_result = False
1929 error_reason = "At least one pad is smaller than zero"
1930
1931 if check:
1932 pad = kwargs['pad']
1933 if min(pad) < 0:
1934 error_result = True
1935
1936 info_dict = {
1937 "error_name": error_name,
1938 "error_result": error_result,
1939 "error_reason": error_reason,
1940 "param_reqs": param_reqs
1941 }
1942 return info_dict
1943
1944
1945 @staticmethod
1946 def evPadLargerEqualKernel(check=False, **kwargs):
1947 error_name = ErrorIf.PadLargerEqualKernel
1948 param_reqs = {"rank": None, "dtype": None, "shape": None}
1949 error_result = False
1950 error_reason = "At least one pad is larger than kernel dimension"
1951
1952 if check:
1953 pad = kwargs['pad']
1954 kernel = kwargs['kernel']
1955 if min(pad) > 0 and min(kernel) > 1:
1956 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
1957 error_result = True
1958
1959 info_dict = {
1960 "error_name": error_name,
1961 "error_result": error_result,
1962 "error_reason": error_reason,
1963 "param_reqs": param_reqs
1964 }
1965 return info_dict
1966
1967 @staticmethod
1968 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1969 error_name = ErrorIf.PoolingOutputShapeMismatch
1970 param_reqs = {"rank": None, "dtype": None, "shape": None}
1971 error_result = False
1972 error_reason = "Mismatch between output shape provided and expected output shape"
1973
1974 if check:
1975 pad = kwargs['pad']
1976 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1977
1978 kernel = kwargs['kernel']
1979 kernel_y, kernel_x = kernel[0], kernel[1]
1980
1981 input_shape = kwargs['input_shape']
1982 IH, IW = input_shape[1], input_shape[2]
1983
1984 output_shape = kwargs['output_shape']
1985 OH, OW = output_shape[1], output_shape[2]
1986
1987 stride = kwargs['stride']
1988 stride_y, stride_x = stride[0], stride[1]
1989
1990 # calculate correct height, width dimensions
1991 if stride_x != 0 and stride_y != 0:
1992 y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
1993 x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x
1994
1995 # ensure parameters are valid
1996 params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
1997 and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))
1998
1999 if params_valid and (OH != y_correct or OW != x_correct):
2000 error_result = True
2001
2002 info_dict = {
2003 "error_name": error_name,
2004 "error_result": error_result,
2005 "error_reason": error_reason,
2006 "param_reqs": param_reqs
2007 }
2008 return info_dict
2009
2010
2011 @staticmethod
2012 def evKernelSmallerOne(check=False, **kwargs):
2013 error_name = ErrorIf.KernelSmallerOne
2014 param_reqs = {"rank": None, "dtype": None, "shape": None}
2015 error_result = False
2016 error_reason = "At least one kernel dimension is smaller than zero"
2017
2018 if check:
2019 kernel = kwargs['kernel']
2020 if min(kernel) < 1:
2021 error_result = True
2022
2023 info_dict = {
2024 "error_name": error_name,
2025 "error_result": error_result,
2026 "error_reason": error_reason,
2027 "param_reqs": param_reqs
2028 }
2029 return info_dict
2030
2031 @staticmethod
2032 def evStrideSmallerOne(check=False, **kwargs):
2033 error_name = ErrorIf.StrideSmallerOne
2034 param_reqs = {"rank": None, "dtype": None, "shape": None}
2035 error_result = False
2036 error_reason = "At least one stride dimension is smaller than zero"
2037
2038 if check:
2039 stride = kwargs['stride']
2040 if min(stride) < 1:
2041 error_result = True
2042
2043 info_dict = {
2044 "error_name": error_name,
2045 "error_result": error_result,
2046 "error_reason": error_reason,
2047 "param_reqs": param_reqs
2048 }
2049 return info_dict
2050
Matthew Haddonc2025212021-10-08 21:21:05 +01002051 @staticmethod
2052 def evScaleTrue(check=False, **kwargs):
2053 error_name = ErrorIf.ScaleTrue
2054 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2055 error_result = False
2056 error_reason = "Scale set to true but input type is INT48"
2057
2058 if check:
2059 input_dtype = kwargs['input_dtype']
2060 scale32 = kwargs['scale32']
2061 if scale32 and input_dtype == DType.INT48:
2062 error_result = True
2063
2064 info_dict = {
2065 "error_name": error_name,
2066 "error_result": error_result,
2067 "error_reason": error_reason,
2068 "param_reqs": param_reqs
2069 }
2070 return info_dict
2071
2072 @staticmethod
2073 def evScaleNotTrue(check=False, **kwargs):
2074 error_name = ErrorIf.ScaleNotTrue
2075 param_reqs = {"rank": None, "dtype": None, "shape": None}
2076 error_result = False
2077 error_reason = "Scale set to false but double round set to true"
2078
2079 if check:
2080 scale32 = kwargs['scale32']
2081 double_round = kwargs['double_round']
2082 if not scale32 and double_round:
2083 error_result = True
2084
2085 info_dict = {
2086 "error_name": error_name,
2087 "error_result": error_result,
2088 "error_reason": error_reason,
2089 "param_reqs": param_reqs
2090 }
2091 return info_dict
2092
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002093
2094
class TosaInvalidValidator:
    """Predicates that flag generated argument combinations as invalid.

    Every validator takes keyword arguments from the test generator and
    returns True when the combination is invalid and False otherwise.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True for an unsupported RESIZE mode / dtype combination.

        Reads ``input_dtype`` and ``args``; ``args[0]`` is the resize mode and
        ``args[8]`` the output dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Invalid unless (input, output) is one of the supported pairs.
            # Previously `not A or not B or not C`, which is always True
            # because at most one pair can match.
            supported = (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
            return not supported
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (
                (input_dtype != output_dtype) or
                (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True when the stride used for ``input_dtype`` is <= 0.

        ``args[1]`` is the integer stride and ``args[4]`` the float stride;
        FLOAT inputs are checked against the float stride.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride = args[1]
        stride_fp = args[4]

        active_stride = stride_fp if input_dtype == DType.FLOAT else stride
        if active_stride[0] <= 0 or active_stride[1] <= 0:
            # Negative or zero stride
            return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True when the computed output height or width is <= 0.

        Reads ``opName``, ``shapeList`` and ``args``:
          * ``shapeList[0]`` is the input shape (indexes 1, 2 are H, W);
            non-pooling ops also use ``shapeList[1]`` as the filter shape.
          * ``args[0]`` strides, ``args[1]`` padding (top, bottom, left,
            right), ``args[2]`` dilations for conv ops or the kernel for
            pooling ops.

        Raises:
            AssertionError: for an op name that is not conv2d,
                depthwise_conv2d or *pool2d.
        """
        opName = kwargs['opName']
        shapes = kwargs['shapeList']
        args = kwargs['args']

        ifm_shape = shapes[0]
        is_pool = opName.endswith("pool2d")
        weights_shape = None if is_pool else shapes[1]
        strides, padding, dilations = args[0], args[1], args[2]

        if opName.startswith('conv2d'):
            kernel_h, kernel_w = weights_shape[1], weights_shape[2]
        elif opName.startswith("depthwise_conv2d"):
            kernel_h, kernel_w = weights_shape[0], weights_shape[1]
        elif is_pool:
            kernel = args[2]
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
            # Invalid parameter combination
            return bool(h <= 0 or w <= 0)
        else:
            # Explicit raise so the check survives `python -O`
            raise AssertionError("Unrecognized Op")

        # Dilated convolution output size
        h = (
            ifm_shape[1]
            - kernel_h
            - (kernel_h - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm_shape[2]
            - kernel_w
            - (kernel_w - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        # Invalid parameter combination
        return bool(h <= 0 or w <= 0)

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True when ``args[3]`` (output shape) has H or W <= 0."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
2211
2212
Kevin Cheng550ccc52021-03-03 11:21:43 -08002213
Eric Kunzee5e26762020-10-13 16:11:07 -07002214class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002215 # Maximum rank of tensor supported by test generator.
2216 TOSA_TENSOR_MAX_RANK = 6
2217
Eric Kunzee5e26762020-10-13 16:11:07 -07002218 def __init__(self, args):
2219 self.args = args
2220 self.basePath = args.output_dir
2221 self.random_seed = args.random_seed
2222 self.ser = None
2223 self.rng = np.random.default_rng(self.random_seed)
2224 self.createDynamicOpLists()
2225 self.initOpListDefaults()
2226 self.quantGen = TosaQuantGen()
2227 # Force makeShape to do a specific starting shape
2228 self.targetted_shape = None
2229
2230 def createSerializer(self, opName, testPath):
2231 self.testPath = os.path.join(opName, testPath)
2232
2233 fullPath = os.path.join(self.basePath, self.testPath)
2234 os.makedirs(fullPath, exist_ok=True)
2235 self.ser = ts.TosaSerializer(fullPath)
2236
2237 def getSerializer(self):
2238 return self.ser
2239
2240 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002241 with open(
2242 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
2243 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07002244 fd.write(self.ser.serialize())
2245
Kevin Cheng550ccc52021-03-03 11:21:43 -08002246 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
2247 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002248
Matthew Haddon74567092021-07-16 15:38:20 +01002249 def resetRNG(self, seed=None):
2250 if seed == None:
2251 seed = self.random_seed + 1
2252 self.rng = np.random.default_rng(seed)
2253
Eric Kunzee5e26762020-10-13 16:11:07 -07002254 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07002255 if dtype == DType.BOOL:
2256 np_dt = np.bool
2257 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07002258 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002259 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002260 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002261 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002262 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
2263 elif dtype == DType.UINT8:
2264 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002265 elif dtype == DType.INT16:
2266 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
2267 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002268 return np.int32(
2269 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
2270 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002271 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002272 return np.int64(
2273 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
2274 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002275 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002276 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002277 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002278 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002279
Kevin Cheng989cb052021-04-28 16:29:44 -07002280 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002281 placeholders = []
2282
Kevin Cheng989cb052021-04-28 16:29:44 -07002283 assert len(shape_list) == len(dtype_list)
2284
2285 for idx, shape in enumerate(shape_list):
2286 arr = self.getRandTensor(shape, dtype_list[idx])
2287 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002288
2289 return placeholders
2290
Kevin Cheng989cb052021-04-28 16:29:44 -07002291 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002292 consts = []
2293
Kevin Cheng989cb052021-04-28 16:29:44 -07002294 assert len(shape_list) == len(dtype_list)
2295
2296 for idx, shape in enumerate(shape_list):
2297 arr = self.getRandTensor(shape, dtype_list[idx])
2298 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002299
2300 return consts
2301
2302 def makeShape(self, rank):
2303 if self.targetted_shape:
2304 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002305 return np.int32(
2306 self.rng.integers(
2307 low=self.args.tensor_shape_range[0],
2308 high=self.args.tensor_shape_range[1],
2309 size=rank,
2310 )
2311 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002312
2313 def setTargetShape(self, shape):
2314 self.targetted_shape = shape
2315
2316 def randInt(self, low=0, high=256):
2317 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
2318
2319 def getRandNumberDType(self, dtype):
2320 if dtype == DType.FLOAT:
2321 return self.rng.random()
2322 elif dtype == DType.BOOL:
2323 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07002324 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002325 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002326 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002327 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002328 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07002329 elif dtype == DType.INT16:
2330 low, high = (-32768, 32768)
2331 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002332 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07002333 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002334 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07002335 # Special size
2336 return np.int64(self.rng.integers(low, high, size=1))[0]
2337 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002338 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002339
2340 return np.int32(self.rng.integers(low, high, size=1))[0]
2341
2342 def shapeStr(self, shape):
2343
2344 sStr = []
2345 # Convert to strings
2346 for i in shape:
2347 sStr.append(str(i))
2348
Kevin Cheng550ccc52021-03-03 11:21:43 -08002349 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002350
2351 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07002352 if isinstance(t, list):
2353 assert len(t) >= 2
2354 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002355 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002356 if t == DType.BOOL:
2357 return "b"
2358 elif t == DType.INT4:
2359 return "i4"
2360 elif t == DType.INT8:
2361 return "i8"
2362 elif t == DType.UINT8:
2363 return "u8"
2364 elif t == DType.INT16:
2365 return "i16"
2366 elif t == DType.INT32:
2367 return "i32"
2368 elif t == DType.INT48:
2369 return "i48"
2370 elif t == DType.FLOAT:
2371 return "float"
2372 else:
2373 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002374
2375 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002376 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002377 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002378 return 4
2379 elif t == DType.INT8:
2380 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002381 elif t == DType.UINT8:
2382 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002383 elif t == DType.INT16:
2384 return 16
2385 elif t == DType.INT32:
2386 return 32
2387 elif t == DType.INT48:
2388 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01002389 elif t == DType.FLOAT:
2390 return 32
2391 elif t == DType.BOOL:
2392 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002393 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002394 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002395
    # Operator build functions
    # Argument generators return a list of tuples
    # (stringDescriptor, [build_fcn_arg_list]), where the string descriptor
    # is used to generate the test name and the build_fcn_arg_list is
    # expanded and passed to the operator test build functions below.
2401
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002402 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
2403 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
2404
Matthew Haddon848efb42021-09-09 12:30:53 +01002405 # build_placeholder returns an int, ABS/other ops does not
2406 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002407 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
2408 return result_tens
2409 elif op['op'] == Op.IDENTITY:
2410 self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
2411 return result_tens
2412
2413 # Ensure new output type has correct qinfo
2414 if error_name == ErrorIf.WrongOutputType:
2415 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
2416 qinfo = ts.TosaSerializerQuantInfo()
2417 qinfo.UnaryQuantInfo(
2418 TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2419 )
2420
2421 # Invalidate Input/Output list for error if checks.
2422 input_list = [a.name]
2423 output_list = [result_tens.name]
2424 pCount, cCount = op["operands"]
2425 num_operands = pCount + cCount
2426 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2427
2428 TosaErrorValidator.evValidateErrorIfs(
2429 self.ser,
2430 validator_fcns,
2431 error_name,
2432 op=op,
2433 input_dtype=a.dtype,
2434 output_dtype=result_tens.dtype,
2435 qinfo = qinfo,
2436 result_tensor = result_tens,
2437 input_list=input_list,
2438 output_list=output_list,
2439 num_operands=num_operands,
2440 )
2441
2442 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002443 return result_tens
2444
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002445 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
2446 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
2447
2448
2449 # Invalidate Input/Output list for error if checks.
2450 input_list = [a.name, b.name]
2451 output_list = [result_tens.name]
2452 pCount, cCount = op["operands"]
2453 num_operands = pCount + cCount
2454 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2455
2456 TosaErrorValidator.evValidateErrorIfs(
2457 self.ser,
2458 validator_fcns,
2459 error_name,
2460 op=op,
2461 input1 = a,
2462 input2 = b,
2463 input_dtype = a.dtype,
2464 output_dtype = result_tens.dtype,
2465 result_tensor = result_tens,
2466 input_list=input_list,
2467 output_list=output_list,
2468 num_operands=num_operands,
2469 )
2470
2471 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07002472 return result_tens
2473
2474 def build_binary_nonbroadcast(self, op, a, b):
2475 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002476 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002477 return result_tens
2478
Kevin Chengaee1fac2020-11-11 13:54:06 -08002479 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002480 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002481
2482 attr = ts.TosaSerializerAttribute()
2483 attr.ArithmeticRightShiftAttribute(round)
2484
Matthew Haddon848efb42021-09-09 12:30:53 +01002485 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002486 return result_tens
2487
2488 def build_mul(self, op, a, b, shift):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002489 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Eric Kunzee5e26762020-10-13 16:11:07 -07002490
2491 # Special for multiply:
2492 # Force the result to INT32 for INT types
2493 if a.dtype != DType.FLOAT:
2494 result_tens.setDtype(DType.INT32)
2495
Kevin Chengaee1fac2020-11-11 13:54:06 -08002496 attr = ts.TosaSerializerAttribute()
2497 attr.MulAttribute(shift)
2498
Matthew Haddon848efb42021-09-09 12:30:53 +01002499 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002500 return result_tens
2501
2502 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002503 # Constant size depending on type, random values
2504 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07002505 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002506 table_arr = self.getRandTensor([513], table_dtype)
2507 else:
2508 assert a.dtype == DType.INT8
2509 table_dtype = DType.INT8
2510 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002511
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002512 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
2513 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002514 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002515
2516 return result_tens
2517
2518 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07002519 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002520 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002521 return result_tens
2522
2523 def build_comparison(self, op, a, b):
2524 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002525 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002526 return result_tens
2527
2528 def build_argmax(self, op, a, axis):
2529 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
2530
2531 attr = ts.TosaSerializerAttribute()
2532 attr.AxisAttribute(axis)
2533
Matthew Haddon848efb42021-09-09 12:30:53 +01002534 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002535 return result_tens
2536
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002537 def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
2538 result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
2539
2540 # Ensure new output type has correct qinfo
2541 if error_name == ErrorIf.WrongInputType:
2542 if input.dtype not in [DType.INT8, DType.UINT8]:
2543 qinfo = ts.TosaSerializerQuantInfo()
2544 qinfo.UnaryQuantInfo(
2545 TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2546 )
2547
2548 # Invalidate Input/Output list for error if checks.
2549 input_list = [input.name]
2550 output_list = [result_tens.name]
2551 pCount, cCount = op["operands"]
2552 num_operands = pCount + cCount
2553 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2554
2555 TosaErrorValidator.evValidateErrorIfs(
2556 self.ser,
2557 validator_fcns,
2558 error_name,
2559 op=op,
2560 input_shape=input.shape,
2561 input_dtype=input.dtype,
2562 output_shape=result_tens.shape,
2563 output_dtype=result_tens.dtype,
2564 kernel=kernel,
2565 stride=stride,
2566 pad=pad,
2567 qinfo = qinfo,
2568 result_tensor = result_tens,
2569 input_list=input_list,
2570 output_list=output_list,
2571 num_operands=num_operands,
2572 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002573
2574 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002575 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07002576
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002577 self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002578 return result_tens
2579
2580 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002581 assert len(padding) == 4
2582 result_tens = OutputShaper.conv2dOp(
2583 self.ser, ifm, filter, strides, padding, dilations
2584 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002585
2586 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002587 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002588
Kevin Cheng550ccc52021-03-03 11:21:43 -08002589 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002590 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002591 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002592 return result_tens
2593
Kevin Cheng1533b852021-09-01 12:51:58 -07002594 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
2595 assert len(padding) == 6
2596 result_tens = OutputShaper.conv3dOp(
2597 self.ser, ifm, filter, strides, padding, dilations
2598 )
2599
2600 attr = ts.TosaSerializerAttribute()
2601 attr.ConvAttribute(padding, strides, dilations)
2602
2603 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002604 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07002605 )
2606 return result_tens
2607
Kevin Cheng550ccc52021-03-03 11:21:43 -08002608 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07002609 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002610 ):
2611 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07002612 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
2613
2614 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002615 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002616
Kevin Cheng550ccc52021-03-03 11:21:43 -08002617 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002618 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002619 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002620 return result_tens
2621
Kevin Cheng550ccc52021-03-03 11:21:43 -08002622 def build_depthwise_conv2d(
2623 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
2624 ):
2625 result_tens = OutputShaper.depthwiseConv2dOp(
2626 self.ser, ifm, filter, strides, padding, dilations
2627 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002628
2629 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002630 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002631
Kevin Cheng550ccc52021-03-03 11:21:43 -08002632 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002633 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002634 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002635 return result_tens
2636
2637 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
2638 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
2639
Kevin Cheng550ccc52021-03-03 11:21:43 -08002640 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002641 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002642 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002643 return result_tens
2644
2645 def build_matmul(self, op, a, b, qinfo):
2646 result_tens = OutputShaper.matmulOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002647 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002648 return result_tens
2649
Matthew Haddond6ce7252021-09-29 15:35:44 +01002650 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
2651 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
2652
2653 # Invalidate Input/Output list for error if checks.
2654 input_list = [a.name]
2655 output_list = [result_tens.name]
2656 pCount, cCount = op["operands"]
2657 num_operands = pCount + cCount
2658 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2659
2660 TosaErrorValidator.evValidateErrorIfs(
2661 self.ser,
2662 validator_fcns,
2663 error_name,
2664 op=op,
2665 axis = axis,
2666 input_shape = a.shape,
2667 output_shape = result_tens.shape,
2668 input_dtype = a.dtype,
2669 output_dtype = result_tens.dtype,
2670 result_tensor = result_tens,
2671 input_list=input_list,
2672 output_list=output_list,
2673 num_operands=num_operands,
2674 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002675
2676 attr = ts.TosaSerializerAttribute()
2677 attr.AxisAttribute(axis)
2678
Matthew Haddond6ce7252021-09-29 15:35:44 +01002679 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002680 return result_tens
2681
2682 def build_clamp(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002683 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002684
2685 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01002686 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07002687
2688 if a.dtype == DType.FLOAT:
2689 attr.ClampAttribute(0, 0, min(v), max(v))
2690 else:
2691 attr.ClampAttribute(min(v), max(v), 0, 0)
2692
Matthew Haddon848efb42021-09-09 12:30:53 +01002693 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002694 return result_tens
2695
2696 def build_leaky_relu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002697 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002698 attr = ts.TosaSerializerAttribute()
2699
2700 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
2701
Matthew Haddon848efb42021-09-09 12:30:53 +01002702 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002703 return result_tens
2704
2705 # Needs an additional type/input
2706 def build_prelu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002707 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002708
Matthew Haddon848efb42021-09-09 12:30:53 +01002709 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002710 return result_tens
2711
Eric Kunzee5e26762020-10-13 16:11:07 -07002712 def build_sigmoid(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002713 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01002714 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002715 return result_tens
2716
2717 def build_tanh(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002718 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01002719 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002720 return result_tens
2721
Matthew Haddon818ab902021-07-27 09:12:49 +01002722 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07002723 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01002724
2725 # To store variable length list of input tensors we need to store axis along with it
2726 axis = a[-1]
2727 a = a[:-1]
2728
2729 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002730
2731 attr = ts.TosaSerializerAttribute()
2732 attr.AxisAttribute(axis)
2733
Matthew Haddon818ab902021-07-27 09:12:49 +01002734 input_tensor_names = []
2735 for tensor in a:
2736 input_tensor_names.append(tensor.name)
2737
Matthew Haddon848efb42021-09-09 12:30:53 +01002738 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
2739 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002740
2741 def build_pad(self, op, a, padding, qinfo):
2742 result_tens = OutputShaper.padOp(self.ser, a, padding)
2743
2744 # Need to turn the padding array into a TOSA tensor here.
2745 # This is one of the few tensor operands that does not get
2746 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08002747 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07002748
Kevin Cheng550ccc52021-03-03 11:21:43 -08002749 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002750 op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002751 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002752 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002753
2754 def build_reshape(self, op, a, newShape):
2755 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
2756
2757 attr = ts.TosaSerializerAttribute()
2758 attr.ReshapeAttribute(newShape)
2759
Matthew Haddon848efb42021-09-09 12:30:53 +01002760 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002761 return result_tens
2762
2763 def build_reverse(self, op, a, axis):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002764 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002765
2766 attr = ts.TosaSerializerAttribute()
2767 attr.AxisAttribute(axis)
2768
Matthew Haddon848efb42021-09-09 12:30:53 +01002769 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002770 return result_tens
2771
2772 def build_transpose(self, op, a, perms):
2773 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
2774
Kevin Cheng550ccc52021-03-03 11:21:43 -08002775 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07002776
Matthew Haddon848efb42021-09-09 12:30:53 +01002777 self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002778 return result_tens
2779
2780 def build_slice(self, op, a, begin, size):
2781 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
2782
2783 attr = ts.TosaSerializerAttribute()
2784 attr.SliceAttribute(begin, size)
2785
Matthew Haddon848efb42021-09-09 12:30:53 +01002786 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002787 return result_tens
2788
2789 def build_tile(self, op, a, multiples):
2790 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
2791
2792 attr = ts.TosaSerializerAttribute()
2793 attr.TileAttribute(multiples)
2794
Matthew Haddon848efb42021-09-09 12:30:53 +01002795 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002796 return result_tens
2797
Kevin Cheng77d0f762020-11-24 10:26:32 -08002798 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07002799
2800 # Create a new indicies tensor
2801 # here with data that doesn't exceed the dimensions of the values tensor
2802
Kevin Cheng550ccc52021-03-03 11:21:43 -08002803 K = values.shape[1] # K
2804 W = self.randInt(
2805 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
2806 ) # W
2807 indicies_arr = np.int32(
2808 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
2809 ) # (N, W)
2810 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002811
Kevin Cheng77d0f762020-11-24 10:26:32 -08002812 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07002813
Matthew Haddon848efb42021-09-09 12:30:53 +01002814 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002815
2816 return result_tens
2817
Kevin Cheng77d0f762020-11-24 10:26:32 -08002818 def build_scatter(self, op, values_in, input):
2819
2820 # Create a new indicies tensor
2821 # here with data that doesn't exceed the dimensions of the values_in tensor
2822
Kevin Cheng550ccc52021-03-03 11:21:43 -08002823 K = values_in.shape[1] # K
2824 W = input.shape[1] # W
2825 indicies_arr = np.int32(
2826 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
2827 ) # (N, W)
2828 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002829
2830 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
2831
Kevin Cheng550ccc52021-03-03 11:21:43 -08002832 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002833 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002834 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002835
2836 return result_tens
2837
Matthew Haddon848efb42021-09-09 12:30:53 +01002838
Kevin Cheng550ccc52021-03-03 11:21:43 -08002839 def build_resize(
2840 self,
2841 op,
2842 input,
2843 mode,
2844 stride,
2845 offset,
2846 shift,
2847 stride_fp,
2848 offset_fp,
2849 output_dims,
2850 input_dtype,
2851 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002852 validator_fcns,
2853 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002854 ):
2855 result_tens = OutputShaper.resizeOp(
2856 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002857 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002858 input,
2859 mode,
2860 stride,
2861 offset,
2862 shift,
2863 stride_fp,
2864 offset_fp,
2865 output_dims,
2866 input_dtype,
2867 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002868 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08002869 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002870
Matthew Haddon848efb42021-09-09 12:30:53 +01002871 # Invalidate Input/Output list for error if checks.
2872 input_list = [input.name]
2873 output_list = [result_tens.name]
2874 pCount, cCount = op["operands"]
2875 num_operands = pCount + cCount
2876 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01002877
Matthew Haddon848efb42021-09-09 12:30:53 +01002878 TosaErrorValidator.evValidateErrorIfs(
2879 self.ser,
2880 validator_fcns,
2881 error_name,
2882 op=op,
2883 mode=mode,
2884 shift=shift,
2885 input_dtype=input_dtype,
2886 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002887 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01002888 output_shape=output_dims,
2889 offset=offset,
2890 offset_fp=offset_fp,
2891 stride=stride,
2892 stride_fp=stride_fp,
2893 input_list=input_list,
2894 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002895 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01002896 num_operands=num_operands,
2897 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002898
Eric Kunzee5e26762020-10-13 16:11:07 -07002899 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08002900
Kevin Cheng550ccc52021-03-03 11:21:43 -08002901 attr.ResizeAttribute(
2902 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
2903 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002904
Matthew Haddon848efb42021-09-09 12:30:53 +01002905 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002906 return result_tens
2907
2908 def build_identityn(self, op, val, val2):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002909 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
2910 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002911 self.ser.addOperator(
2912 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
2913 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002914 return result_tens
2915
Kevin Cheng17e92022021-10-01 14:33:33 -07002916 def build_const(self, op, val):
2917 self.ser.addOutputTensor(val)
2918 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002919
2920 # Type Conversion
2921 def build_cast(self, op, val, out_dtype):
2922 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002923 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002924 return result_tens
2925
Matthew Haddonc2025212021-10-08 21:21:05 +01002926 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
Eric Kunzee5e26762020-10-13 16:11:07 -07002927 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2928
2929 if per_channel:
2930 nc = val.shape[-1]
2931 else:
2932 nc = 1
2933
2934 in_type_width = self.typeWidth(val.dtype)
2935 out_type_width = self.typeWidth(out_dtype)
2936
Kevin Cheng3a478572021-01-22 17:21:02 -08002937 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002938 input_zp = self.randInt(-128, 128)
2939 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002940 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002941 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002942 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01002943 elif error_name == ErrorIf.InputZeroPointNotZero:
2944 input_zp = self.randInt(-128, 128)
2945 if input_zp == 0:
2946 input_zp = input_zp + self.rng.integers(1, 10)
2947 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002948 else:
2949 input_zp = 0
2950
Kevin Cheng3a478572021-01-22 17:21:02 -08002951 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002952 output_zp = self.randInt(-128, 128)
2953 out_type_width = out_type_width + 1
2954 elif out_dtype == DType.UINT8:
2955 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002956 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01002957 elif error_name == ErrorIf.OutputZeroPointNotZero:
2958 output_zp = self.randInt(-128, 128)
2959 if output_zp == 0:
2960 output_zp = output_zp + self.rng.integers(1, 10)
2961 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002962 else:
2963 output_zp = 0
2964
2965 # Calculate scale based on:
2966 # scale = a *(2^output_width)/(2^input_width))
2967
2968 a = np.float32(self.rng.random(size=[nc]))
2969 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2970
2971 if scale32:
2972 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002973 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002974 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2975 else:
2976 # Cap the scaling at 2^15 - 1 for scale16
2977 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2978
Kevin Cheng550ccc52021-03-03 11:21:43 -08002979 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002980
2981 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2982 shift_arr = np.int32(np.zeros(shape=[nc]))
2983
2984 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002985 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2986 scale_arr[i], scale32
2987 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002988
Kevin Cheng550ccc52021-03-03 11:21:43 -08002989 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002990
Matthew Haddonc2025212021-10-08 21:21:05 +01002991 # Invalidate Input/Output list for error if checks.
2992 input_list = [val.name]
2993 output_list = [result_tens.name]
2994 pCount, cCount = op["operands"]
2995 num_operands = pCount + cCount
2996 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2997
2998 qinfo = (input_zp, output_zp)
2999 TosaErrorValidator.evValidateErrorIfs(
3000 self.ser,
3001 validator_fcns,
3002 error_name,
3003 op=op,
3004 input_dtype=val.dtype,
3005 output_dtype=out_dtype,
3006 input_shape=val.shape,
3007 qinfo=qinfo,
3008 scale32 = scale32,
3009 double_round = double_round,
3010 input_list=input_list,
3011 output_list=output_list,
3012 result_tensor=result_tens,
3013 num_operands=num_operands,
3014 )
3015
Eric Kunzee5e26762020-10-13 16:11:07 -07003016 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003017 attr.RescaleAttribute(
3018 input_zp,
3019 output_zp,
3020 multiplier_arr,
3021 shift_arr,
3022 scale32,
3023 double_round,
3024 per_channel,
3025 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003026
Matthew Haddonc2025212021-10-08 21:21:05 +01003027 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003028 return result_tens
3029
3030 def build_cond_if_const(self, op, then_tens, else_tens, cond):
3031 # For cond_if with constants, we're supplied with then/else tensors that we ignore
3032 # (except for the generated shap) and the condition. Build Then/Else blocks
3033 # and fill them with const nodes for the body.
3034
3035 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08003036 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07003037
3038 # Make then/else tensors
3039 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01003040 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
3041 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003042
3043 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08003044 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07003045
3046 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08003047 then_block = "THEN_BLOCK"
3048 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07003049 attr = ts.TosaSerializerAttribute()
3050 attr.CondIfAttribute(then_block, else_block)
3051
3052 # Finally, build the op and the two blocks
Matthew Haddon848efb42021-09-09 12:30:53 +01003053 self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003054
3055 self.ser.startBasicBlock(then_block)
3056 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08003057 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003058 self.ser.addOutputTensor(then_tens)
3059
3060 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003061 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003062 self.ser.addOutputTensor(else_tens)
3063
3064 return result_tens
3065
3066 def build_cond_if_binary(self, op, a, b, cond):
3067 # For cond_if with a binary op in the then/else blocks, take a and b and
3068 # alternately add or subtract them based on the condition
3069
3070 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08003071 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07003072
Kevin Cheng550ccc52021-03-03 11:21:43 -08003073 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003074
3075 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08003076 then_block = "THEN_BLOCK"
3077 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07003078 attr = ts.TosaSerializerAttribute()
3079 attr.CondIfAttribute(then_block, else_block)
3080
3081 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08003082 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003083 op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08003084 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003085
Les Bell6040b4d2021-10-11 12:50:31 +01003086 if a.dtype in (DType.FLOAT, DType.INT32):
3087 then_op, else_op = Op.ADD, Op.SUB
3088 elif a.dtype in (DType.INT8, DType.INT16):
3089 then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
3090 else:
3091 assert False, f"No tests for DType: {a.dtype}"
Eric Kunzee5e26762020-10-13 16:11:07 -07003092
Les Bell6040b4d2021-10-11 12:50:31 +01003093 for block, op in ((then_block, then_op), (else_block, else_op)):
3094 self.ser.startBasicBlock(block)
3095 self.ser.addInputTensor(a)
3096 self.ser.addInputTensor(b)
3097 tens = self.ser.addOutput(a.shape, a.dtype)
3098 self.ser.addOperator(op, [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003099
3100 return result_tens
3101
3102 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003103 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07003104
Kevin Cheng550ccc52021-03-03 11:21:43 -08003105 cond_block = "COND_BLOCK"
3106 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07003107
3108 attr = ts.TosaSerializerAttribute()
3109 attr.WhileLoopAttribute(cond_block, body_block)
3110
3111 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08003112 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003113 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08003114 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07003115
3116 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003117 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
3118 a_out = self.ser.addIntermediate(a.shape, a.dtype)
3119 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003120
3121 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08003122 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01003123 op['op'],
Kevin Cheng550ccc52021-03-03 11:21:43 -08003124 [iter.name, a.name, acc.name],
3125 [iter_out.name, a_out.name, acc_out.name],
3126 attr,
3127 )
Kevin Chengb227ae52021-09-02 13:43:17 -07003128 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07003129
3130 # COND block (input: iter, output: cond_tens )
3131 self.ser.startBasicBlock(cond_block)
3132 self.ser.addInputTensor(iter)
3133 self.ser.addInputTensor(a)
3134 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003135 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
3136 cond_tens = self.ser.addOutput([], DType.BOOL)
3137 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003138
3139 # BODY block (input: a, acc, iter, output: a, acc, iter)
3140 # Note that local intermediate tensors need to be declared here for the outputs
3141 self.ser.startBasicBlock(body_block)
3142 self.ser.addInputTensor(iter)
3143 self.ser.addInputTensor(a)
3144 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003145 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
3146 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
3147 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003148 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
3149 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
3150 self.ser.addOutputTensor(iter_body_out)
3151 self.ser.addOutputTensor(a)
3152 self.ser.addOutputTensor(acc_body_out)
3153
3154 return acc_out
3155
Matthew Haddon1c00b712021-10-01 15:51:03 +01003156 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
3157 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
3158 default_test_rank_range = range(1, 5)
3159 if not shapeFilter:
3160 shapeFilter = [None]
3161
3162 # Calculate the filters based on what is requested and what the operator allows
3163 rmin, rmax = op["rank"]
3164 if rankFilter is not None:
3165 cleanRankFilter = []
3166 # Ensure rankFilter values are allowed by operator
3167 for rank in rankFilter:
3168 if rank >= rmin and rank <= rmax:
3169 cleanRankFilter.append(rank)
3170 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01003171 # Ensure default behaviour is bounded by default range or by operator,
3172 # whichever is the smaller range of ranks.
3173 opRankRange = range(rmin, rmax + 1)
3174 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01003175 else:
3176 cleanRankFilter = range(rmin, rmax + 1)
3177
3178 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003179
Matthew Haddon1c00b712021-10-01 15:51:03 +01003180 if dtypeFilter is not None:
3181 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01003182 # Create list of operator dtypes filtered by requested dtypes
3183 for dtype in dtypes:
3184 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01003185 cleanDtypeFilter.append(dtype)
3186 else:
3187 cleanDtypeFilter = dtypes
3188
3189 if testType == 'positive':
3190 filterDict = {
3191 'shapeFilter': shapeFilter,
3192 'rankFilter': cleanRankFilter,
3193 'dtypeFilter': cleanDtypeFilter
3194 }
3195 return filterDict
3196 elif testType == 'negative':
3197 validator_info = validator(check=False, op=op)
3198 error_arguments = validator_info['param_reqs']
3199
3200 #Set parameters as required
3201 if error_arguments['rank'] != None:
3202 rankFilter = error_arguments['rank']
3203 else:
3204 rankFilter = cleanRankFilter
3205
3206 if error_arguments['dtype'] != None:
3207 dtypeFilter = error_arguments['dtype']
3208 else:
3209 dtypeFilter = cleanDtypeFilter
3210
3211 if error_arguments['shape'] != None:
3212 shapeFilter = error_arguments['shape']
3213 else:
3214 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
3215
3216 filterDict = {
3217 'shapeFilter': shapeFilter,
3218 'rankFilter': rankFilter,
3219 'dtypeFilter': dtypeFilter
3220 }
3221 return filterDict
3222
3223
Kevin Cheng550ccc52021-03-03 11:21:43 -08003224 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01003225 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08003226 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07003227
3228 try:
3229 op = self.TOSA_OP_LIST[opName]
3230 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003231 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003232
3233 # Initialize a new random number generator
3234 self.rng = np.random.default_rng(self.random_seed)
3235
Kevin Cheng550ccc52021-03-03 11:21:43 -08003236 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003237
Eric Kunzee5e26762020-10-13 16:11:07 -07003238 # Test list consists of a tuple of:
3239 # (opName, testNameStr, dtype, shapeList, argumentsList)
3240 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003241 if testType == 'negative' and "error_if_validators" in op:
3242 error_if_validators = op["error_if_validators"]
3243 else:
3244 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07003245
Matthew Haddon1c00b712021-10-01 15:51:03 +01003246 for validator in error_if_validators:
3247 if validator is not None:
3248 error_name = validator(check=False, op=op)['error_name']
3249 #print("error_name: ", error_name)
3250 else:
3251 error_name = None
3252
3253 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
3254 cleanRankFilter = filterDict['rankFilter']
3255 cleanDtypeFilter = filterDict['dtypeFilter']
3256 cleanShapeFilter = filterDict['shapeFilter']
3257 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
3258
3259 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07003260 if opName.startswith("conv3d"):
3261 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01003262 for t in cleanDtypeFilter:
3263 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01003264 # Filter out by rank
3265 if shape is not None and len(shape) != r:
3266 continue
Matthew Haddon74567092021-07-16 15:38:20 +01003267 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003268 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003269
Matthew Haddon74567092021-07-16 15:38:20 +01003270 shapeStr = self.shapeStr(shapeList[0])
3271 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07003272
Matthew Haddon74567092021-07-16 15:38:20 +01003273 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
3274 argList = []
3275 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003276 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003277 else:
Matthew Haddon74567092021-07-16 15:38:20 +01003278 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07003279
Matthew Haddon74567092021-07-16 15:38:20 +01003280 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003281 if testType == 'positive':
3282 if argStr:
3283 testStr = "{}_{}_{}_{}".format(
3284 opName, shapeStr, typeStr, argStr
3285 )
3286 else:
3287 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
3288 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01003289 if argStr:
3290 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
3291 opName, error_name, shapeStr, typeStr, argStr
3292 )
3293 else:
3294 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003295
3296 testList.append((opName, testStr, t, error_name, shapeList, args))
3297
3298 if testType == 'positive':
3299 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
3300 if "invalid_test_validators" in op:
3301 invalid_test_validators = op["invalid_test_validators"]
3302 clean_testList = []
3303 for test in testList:
3304 for validator_fcn in invalid_test_validators:
3305 remove_test = False
3306 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
3307 remove_test = True
3308 if not remove_test:
3309 clean_testList.append(test)
3310 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07003311
3312 return testList
3313
Matthew Haddone86fd342021-09-07 16:12:21 +01003314
3315 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07003316 try:
3317 op = self.TOSA_OP_LIST[opName]
3318 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003319 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003320
3321 # Create a serializer
3322 self.createSerializer(opName, testStr)
3323
Kevin Cheng550ccc52021-03-03 11:21:43 -08003324 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003325 if "error_if_validators" in op:
3326 error_if_validators = op["error_if_validators"]
3327 else:
3328 error_if_validators = None
3329
Kevin Cheng550ccc52021-03-03 11:21:43 -08003330 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003331 num_operands = pCount + cCount
3332
3333 if isinstance(dtype_or_dtypeList, list):
3334 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07003335 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003336 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003337 else:
3338 dtypeList = [dtype_or_dtypeList] * (num_operands)
3339
Kevin Cheng93a16282021-08-31 16:14:03 -07003340 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003341 assert (
3342 len(shapeList) == num_operands
3343 ), "shapeList length {} must match number of operands {}".format(
3344 len(shapeList), num_operands
3345 )
3346 assert (
3347 len(dtypeList) == num_operands
3348 ), "dtypeList length {} must match number of operands {}".format(
3349 len(dtypeList), num_operands
3350 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003351
3352 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003353 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003354 except KeyError:
3355 qgen = None
3356
3357 # Build the random tensor operands and the test
3358 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08003359
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003360 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003361
3362 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003363 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003364 else:
3365 qinfo = None
3366
3367 try:
3368 if error_if_validators is None:
3369 if qinfo is not None:
3370 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
3371 else:
3372 resultName = build_fcn(self, op, *tens, *testArgs)
3373 else:
3374 if qinfo is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003375 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003376 else:
3377 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
3378 except TypeError as e:
3379 print(
3380 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
3381 build_fcn, tens, testArgs
3382 )
3383 )
3384 raise e
3385
3386 if resultName is None:
3387 print("Invalid ERROR_IF tests created")
3388
3389 # Save the serialized test
3390 self.serialize("test")
3391
3392
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003393 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01003394 pCount, cCount = op["operands"]
3395
3396 tens = []
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003397 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
Jeremy Johnsonef509a42021-09-07 13:59:47 +01003398 # Make sure the operation does not cause value saturation - where
3399 # the number wraps due to limited number of bits to store the answer
3400 assert (
3401 pCount == 2 and cCount == 0
3402 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01003403 placeholders = []
3404 add = (op["op"] == Op.ADD)
3405 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
3406 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
3407 if add:
3408 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
3409 else:
3410 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
3411
3412 # Work out the saturation limits
3413 max_i32 = (1 << 31)-1
3414 min_i32 = -(1 << 31)
3415 max_arr = np.full(shapeList[1], max_i32)
3416 min_arr = np.full(shapeList[1], min_i32)
3417
3418 # Find how much values exceed the maximum/minimums
3419 sat_max_arr = np.maximum(res_arr - max_arr, 0)
3420 sat_min_arr = np.minimum(res_arr - min_arr, 0)
3421
3422 if not add:
3423 # Swap saturation values and negate values as we need to perform opposite operations
3424 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
3425
3426 # Create new array of unsaturated values by clipping values as needed
3427 b_unsat_arr = b_arr
3428 if (sat_max_arr != 0).any():
3429 # Clip values that cause saturation
3430 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
3431 # Reduce axes in unsaturated tensor to match original tensor
3432 for axis, dim in enumerate(b_arr.shape):
3433 if dim != b_unsat_arr.shape[axis]:
3434 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
3435 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
3436
3437 if (sat_min_arr != 0).any():
3438 # Clip values that cause saturation
3439 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
3440 # Reduce axes in unsaturated tensor to match original tensor
3441 for axis, dim in enumerate(b_arr.shape):
3442 if dim != b_unsat_arr.shape[axis]:
3443 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
3444 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
3445
3446 placeholders.append(
3447 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
3448 )
3449 placeholders.append(
3450 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
3451 )
3452
3453 tens.extend(placeholders)
3454 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
3455 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003456 assert (
3457 pCount == 2 and cCount == 0
3458 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08003459
3460 placeholders = []
3461 for idx, shape in enumerate(shapeList[:]):
3462 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07003463 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08003464 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07003465 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08003466 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07003467 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08003468 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
3469 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003470 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08003471 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003472 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07003473 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08003474
3475 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01003476 elif op["op"] == Op.SELECT:
3477 # Set datatype of condition tensor to boolean
3478 dtypeList[0] = DType.BOOL
3479 tens.extend(
3480 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
3481 )
3482 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003483 elif op["op"] == Op.INTDIV and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003484 assert (
3485 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01003486 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003487
3488 placeholders = []
3489
Matthew Haddon459443c2021-08-23 16:43:13 +01003490 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003491 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07003492 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003493 while True:
3494 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
3495 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
3496
3497 if (divisor_arr == 0).any():
3498 continue
3499
Kevin Cheng47315e12021-05-13 17:41:28 -07003500 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003501 continue
3502
3503 break
3504
3505 placeholders.append(
3506 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
3507 )
3508 placeholders.append(
3509 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
3510 )
3511
3512 tens.extend(placeholders)
3513 elif op["op"] == Op.MUL:
3514 assert (
3515 pCount == 2 and cCount == 0
3516 ), "Op.MUL must have 2 placeholders, 0 consts"
3517
3518 if dtypeList[0] == DType.FLOAT:
3519 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
3520 else:
3521 placeholders = []
3522
3523 # Make sure multiply result in int32 range
3524 shift = testArgs[0]
3525 if dtypeList[0] == DType.INT8:
3526 num_bits = 8
3527 elif dtypeList[0] == DType.INT16:
3528 num_bits = 16
3529 elif dtypeList[0] == DType.INT32:
3530 num_bits = 32
3531 else:
3532 raise Exception("OpMul: invalid input dtype")
3533
3534 for idx, shape in enumerate(shapeList[:]):
3535 low = -(2 ** (num_bits - 1))
3536 high = (2 ** (num_bits - 1)) - 1
3537
3538 a_arr = np.int32(
3539 self.rng.integers(low=low, high=high, size=shapeList[0])
3540 )
3541 b_arr = np.int32(
3542 self.rng.integers(low=low, high=high, size=shapeList[1])
3543 )
3544
3545 i = 0
3546 while True:
3547
3548 a_arr_64 = a_arr.astype(np.int64)
3549 b_arr_64 = b_arr.astype(np.int64)
3550
3551 if shift > 0:
3552 rounding = 1 << (shift - 1)
3553 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
3554 else:
3555 result_arr = a_arr_64 * b_arr_64
3556
3557 if (result_arr > -(2 ** 31)).all() and (
3558 result_arr <= ((2 ** 31) - 1)
3559 ).all():
3560 break
3561
3562 i = i + 1
3563 a_arr = a_arr // 2
3564 b_arr = b_arr // 2
3565
3566 placeholders.append(
3567 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
3568 )
3569 placeholders.append(
3570 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
3571 )
3572
3573 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01003574 elif op["op"] == Op.CONCAT:
3575 count = len(shapeList) - self.args.num_const_inputs_concat
3576 if count < 1:
3577 count = 1
3578 if self.args.num_const_inputs_concat == 0:
3579 count = len(shapeList)
3580
3581 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
3582 tens.extend(
3583 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
3584 )
3585 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08003586 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07003587 tens.extend(
3588 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
3589 )
3590 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07003591
Matthew Haddon1c00b712021-10-01 15:51:03 +01003592 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003593
3594 def createDynamicOpLists(self):
3595
3596 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07003597 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003598
Kevin Cheng1533b852021-09-01 12:51:58 -07003599 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003600 testName = "conv2d_{}x{}".format(k[0], k[1])
3601 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
3602 self.TOSA_OP_LIST[testName]["filter"] = k
3603 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003604
Kevin Cheng550ccc52021-03-03 11:21:43 -08003605 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
3606 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3607 "depthwise_conv2d_TEMPLATE"
3608 ].copy()
3609 self.TOSA_OP_LIST[testName]["filter"] = k
3610 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003611
Kevin Cheng550ccc52021-03-03 11:21:43 -08003612 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
3613 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3614 "transpose_conv2d_TEMPLATE"
3615 ].copy()
3616 self.TOSA_OP_LIST[testName]["filter"] = k
3617 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003618
Kevin Cheng1533b852021-09-01 12:51:58 -07003619 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
3620 for k in KERNELS_3D:
3621 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
3622 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
3623 self.TOSA_OP_LIST[testName]["filter"] = k
3624 self.TOSA_OP_LIST[testName]["template"] = False
3625
Eric Kunzee5e26762020-10-13 16:11:07 -07003626 # Delete any templates after having created any dynamic ops
3627 # This is a two-pass operation because it's bad practice to delete
3628 # keys from dictionaries while iterating
3629 keyList = []
3630 for k in self.TOSA_OP_LIST:
3631 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003632 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07003633 keyList.append(k)
3634 continue
3635 except KeyError:
3636 pass
3637
3638 for k in keyList:
3639 del self.TOSA_OP_LIST[k]
3640
3641 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003642 """Fill in default fields for ops if they aren't already specified.
3643 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07003644 for op in self.TOSA_OP_LIST:
3645
3646 # Required fields
3647 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003648 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003649 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003650 raise Exception(
3651 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
3652 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003653
3654 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003655 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003656 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003657 raise Exception(
3658 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
3659 op
3660 )
3661 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003662
3663 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003664 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003665 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003666 raise Exception(
3667 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
3668 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003669
3670 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003671 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003672 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003673 raise Exception(
3674 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
3675 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003676
3677 # Put in default rank range, if missing
3678 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003679 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003680 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003681 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003682
3683 # Tensor operator list
3684 # 'op': op name
3685 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08003686 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
3687 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07003688 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
3689 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08003690 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07003691
Kevin Cheng550ccc52021-03-03 11:21:43 -08003692 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
3693 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07003694
Kevin Cheng550ccc52021-03-03 11:21:43 -08003695 TYPE_BOOL = [DType.BOOL]
3696 TYPE_FI32 = [DType.FLOAT, DType.INT32]
3697 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
3698 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003699
Kevin Cheng550ccc52021-03-03 11:21:43 -08003700 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07003701
Kevin Cheng1533b852021-09-01 12:51:58 -07003702 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003703 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003704 [DType.INT8, DType.INT8, DType.INT32],
3705 [DType.INT16, DType.INT8, DType.INT48],
3706 DType.FLOAT,
3707 ]
3708
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003709 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003710
# Table of operators the generator can build tests for. Each entry maps a
# test name to a dict with (some of) these keys:
#   "op":        the Op value placed in the generated graph.
#   "operands":  a pair of counts; presumably (placeholder inputs, const
#                inputs) -- TODO confirm against the consumer of this table.
#   "rank":      optional rank range; entries without it presumably use
#                DEFAULT_RANK_RANGE.
#   "build_fcn": (operator builder, tensor generator, argument generator
#                or None).
#   "qgen":      optional quantization-info generator.
#   "types":     dtypes to test (dtype triples for TYPE_CONV users).
#   "invalid_test_validators" / "error_if_validators": optional tuples of
#                validators (names suggest invalid-test filtering and
#                ERROR_IF negative-test generation respectively).
#   "template":  True for entries filled in by createDynamicOpLists.
TOSA_OP_LIST = {
    # Tensor operators
    "argmax": {
        "op": Op.ARGMAX,
        "operands": (1, 0),
        "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_NARROW_INT_FP,
    },
    "avg_pool2d": {
        "op": Op.AVG_POOL2D,
        "operands": (1, 0),
        "rank": (4, 4),
        "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
        "qgen": TosaQuantGen.qgUnary,
        "types": TYPE_NARROW_INT_FP,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        "error_if_validators": (
            TosaErrorValidator.evKernelSmallerOne,
            TosaErrorValidator.evStrideSmallerOne,
            TosaErrorValidator.evPadSmallerZero,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
            TosaErrorValidator.evInputZeroPointNotZero,
            TosaErrorValidator.evOutputZeroPointNotZero,
            TosaErrorValidator.evPadLargerEqualKernel,
            TosaErrorValidator.evPoolingOutputShapeMismatch,
        ),
    },
    # Templated operator. Filled in by createDynamicOpLists
    "conv2d_TEMPLATE": {
        "op": Op.CONV2D,
        "operands": (1, 2),
        "rank": (4, 4),
        "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        "template": True,
    },
    # Templated operator. Filled in by createDynamicOpLists
    "conv3d_TEMPLATE": {
        "op": Op.CONV3D,
        "operands": (1, 2),
        "rank": (5, 5),
        "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "template": True,
    },
    # Templated operator. Filled in by createDynamicOpLists
    "depthwise_conv2d_TEMPLATE": {
        "op": Op.DEPTHWISE_CONV2D,
        "operands": (1, 2),
        "filter": [1, 1],
        "rank": (4, 4),
        "build_fcn": (
            build_depthwise_conv2d,
            TosaTensorGen.tgDepthwiseConv2D,
            TosaArgGen.agConv,
        ),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        "template": True,
    },
    "fully_connected": {
        "op": Op.FULLY_CONNECTED,
        "operands": (1, 2),
        "rank": (2, 2),
        "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
    },
    "matmul": {
        "op": Op.MATMUL,
        "operands": (2, 0),
        "rank": (3, 3),
        "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
        "qgen": TosaQuantGen.qgMatmul,
        "types": TYPE_NARROW_INT_FP,
    },
    "max_pool2d": {
        "op": Op.MAX_POOL2D,
        "operands": (1, 0),
        "rank": (4, 4),
        "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
        "types": TYPE_NARROW_INT_FP,
        "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        "error_if_validators": (
            TosaErrorValidator.evKernelSmallerOne,
            TosaErrorValidator.evStrideSmallerOne,
            TosaErrorValidator.evPadSmallerZero,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
            TosaErrorValidator.evPadLargerEqualKernel,
            TosaErrorValidator.evPoolingOutputShapeMismatch,
        ),
    },
    # Templated operator. Filled in by createDynamicOpLists
    "transpose_conv2d_TEMPLATE": {
        "op": Op.TRANSPOSE_CONV2D,
        "operands": (1, 2),
        "rank": (4, 4),
        "build_fcn": (
            build_transpose_conv2d,
            TosaTensorGen.tgTransposeConv2D,
            TosaArgGen.agTransposeConv2D,
        ),
        "qgen": TosaQuantGen.qgConv,
        "types": TYPE_CONV,
        "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
        "template": True,
    },
    # Activation functions
    "clamp": {
        "op": Op.CLAMP,
        "operands": (1, 0),
        "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
        "types": TYPE_NARROW_INT_FP,
    },
    "sigmoid": {
        "op": Op.SIGMOID,
        "operands": (1, 0),
        "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    "tanh": {
        "op": Op.TANH,
        "operands": (1, 0),
        "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
    },
    # Elementwise Binary Operators
    "add": {
        "op": Op.ADD,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "arithmetic_right_shift": {
        "op": Op.ARITHMETIC_RIGHT_SHIFT,
        "operands": (2, 0),
        "build_fcn": (
            build_arithmetic_right_shift,
            TosaTensorGen.tgBroadcastFuzz,
            TosaArgGen.agArithmeticRightShift,
        ),
        "types": TYPE_INT,
    },
    "bitwise_and": {
        "op": Op.BITWISE_AND,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "bitwise_or": {
        "op": Op.BITWISE_OR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "bitwise_xor": {
        "op": Op.BITWISE_XOR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "intdiv": {
        "op": Op.INTDIV,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": [DType.INT32],
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_and": {
        "op": Op.LOGICAL_AND,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_left_shift": {
        "op": Op.LOGICAL_LEFT_SHIFT,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_right_shift": {
        "op": Op.LOGICAL_RIGHT_SHIFT,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_or": {
        "op": Op.LOGICAL_OR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_xor": {
        "op": Op.LOGICAL_XOR,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "maximum": {
        "op": Op.MAXIMUM,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "minimum": {
        "op": Op.MINIMUM,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "mul": {
        "op": Op.MUL,
        "operands": (2, 0),
        "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
        "types": TYPE_INT_FP,
    },
    "pow": {
        "op": Op.POW,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "sub": {
        "op": Op.SUB,
        "operands": (2, 0),
        "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evRankMismatch,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "table": {
        "op": Op.TABLE,
        # Use the automatic generation functions to create the input array
        # but create the table tensor in the build function, as it may be
        # a different type from the input
        "operands": (1, 0),
        "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
        "types": [DType.INT8, DType.INT16],
    },
    # Elementwise Unary operators
    "abs": {
        "op": Op.ABS,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "bitwise_not": {
        "op": Op.BITWISE_NOT,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_INT,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "ceil": {
        "op": Op.CEIL,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "clz": {
        "op": Op.CLZ,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": [DType.INT32],
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "exp": {
        "op": Op.EXP,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "floor": {
        "op": Op.FLOOR,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "log": {
        "op": Op.LOG,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "logical_not": {
        "op": Op.LOGICAL_NOT,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "negate": {
        "op": Op.NEGATE,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "qgen": TosaQuantGen.qgUnary,
        "types": TYPE_INT_FP,
        "error_if_validators": (
            TosaErrorValidator.evInputZeroPointNotZero,
            TosaErrorValidator.evOutputZeroPointNotZero,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "reciprocal": {
        "op": Op.RECIPROCAL,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "rsqrt": {
        "op": Op.RSQRT,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    # Elementwise Ternary operators
    "select": {
        "op": Op.SELECT,
        "operands": (3, 0),
        "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FIB,
    },
    # Comparison operators
    "equal": {
        "op": Op.EQUAL,
        "operands": (2, 0),
        "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
    },
    "greater_equal": {
        "op": Op.GREATER_EQUAL,
        "operands": (2, 0),
        "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
    },
    "greater": {
        "op": Op.GREATER,
        "operands": (2, 0),
        "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
        "types": TYPE_FI32,
    },
    # Reduction operators
    "reduce_all": {
        "op": Op.REDUCE_ALL,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evAxisLargerRank,
            TosaErrorValidator.evAxisSmallerZero,
            TosaErrorValidator.evShapeOfAxisNotOne,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "reduce_any": {
        "op": Op.REDUCE_ANY,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_BOOL,
        "error_if_validators": (
            TosaErrorValidator.evAxisLargerRank,
            TosaErrorValidator.evAxisSmallerZero,
            TosaErrorValidator.evShapeOfAxisNotOne,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "reduce_max": {
        "op": Op.REDUCE_MAX,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_INT_FP,
        "error_if_validators": (
            TosaErrorValidator.evAxisLargerRank,
            TosaErrorValidator.evAxisSmallerZero,
            TosaErrorValidator.evShapeOfAxisNotOne,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "reduce_min": {
        # Was Op.REDUCE_MAX, which made the "reduce_min" tests build and
        # exercise REDUCE_MAX instead.
        "op": Op.REDUCE_MIN,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_INT_FP,
        "error_if_validators": (
            TosaErrorValidator.evAxisLargerRank,
            TosaErrorValidator.evAxisSmallerZero,
            TosaErrorValidator.evShapeOfAxisNotOne,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "reduce_product": {
        "op": Op.REDUCE_PRODUCT,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_FP,
        "error_if_validators": (
            TosaErrorValidator.evAxisLargerRank,
            TosaErrorValidator.evAxisSmallerZero,
            TosaErrorValidator.evShapeOfAxisNotOne,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    "reduce_sum": {
        "op": Op.REDUCE_SUM,
        "operands": (1, 0),
        "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_FI32,
        "error_if_validators": (
            TosaErrorValidator.evAxisLargerRank,
            TosaErrorValidator.evAxisSmallerZero,
            TosaErrorValidator.evShapeOfAxisNotOne,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    # Data layout operators
    "concat": {
        "op": Op.CONCAT,
        "operands": (2, 0),
        "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
        "types": TYPE_FIB,
    },
    "pad": {
        "op": Op.PAD,
        "operands": (1, 0),
        "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
        "qgen": TosaQuantGen.qgPad,
        "types": TYPE_FIB,
    },
    "reshape": {
        "op": Op.RESHAPE,
        "operands": (1, 0),
        "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
        "types": TYPE_FIB,
    },
    "reverse": {
        "op": Op.REVERSE,
        "operands": (1, 0),
        "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
        "types": TYPE_FIB,
    },
    "slice": {
        "op": Op.SLICE,
        "operands": (1, 0),
        "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
        "types": TYPE_FIB,
    },
    "tile": {
        "op": Op.TILE,
        "operands": (1, 0),
        "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
        "types": TYPE_FIB,
    },
    "transpose": {
        "op": Op.TRANSPOSE,
        "operands": (1, 0),
        "rank": (1, 4),
        "build_fcn": (
            build_transpose,
            TosaTensorGen.tgBasic,
            TosaArgGen.agTranspose,
        ),
        "types": TYPE_FIB,
    },
    # Data nodes
    "const": {
        "op": Op.CONST,
        "operands": (0, 1),
        "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
        "types": TYPE_FIB,
    },
    "identity": {
        "op": Op.IDENTITY,
        "operands": (1, 0),
        "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
        "types": TYPE_FIB,
    },
    # Scatter/Gather
    "gather": {
        "op": Op.GATHER,
        # Only specify 'values' tensor here. 'indices' is generated in op building stage
        "operands": (1, 0),
        "rank": (3, 3),
        "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
        "types": TYPE_INT_FP,
    },
    "scatter": {
        "op": Op.SCATTER,
        # Only specify 'values_in' tensor here.
        # 'indices' and 'input' are generated in op building stage
        "operands": (2, 0),
        "rank": (3, 3),
        "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
        "types": TYPE_INT_FP,
    },
    # Image operations
    "resize": {
        "op": Op.RESIZE,
        "operands": (1, 0),
        "rank": (4, 4),
        "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
        "types": [DType.INT8, DType.INT16, DType.FLOAT],
        "invalid_test_validators": (
            TosaInvalidValidator.ivWrongDataTypeOrModeResize,
            TosaInvalidValidator.ivBadStride,
        ),
        "error_if_validators": (
            TosaErrorValidator.evMaxDimExceeded,
            TosaErrorValidator.evStrideSmallerEqualZero,
            TosaErrorValidator.evStrideLargerDimension,
            TosaErrorValidator.evStrideLargerEqualMax,
            TosaErrorValidator.evOffsetSmallerEqualMin,
            TosaErrorValidator.evOffsetLargerEqualMax,
            TosaErrorValidator.evShiftNotZero,
            TosaErrorValidator.evShiftSmallerOne,
            TosaErrorValidator.evShiftLargerEleven,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
            TosaErrorValidator.evBatchMismatch,
            TosaErrorValidator.evChannelMismatch,
        ),
    },
    # Type conversion
    "cast": {
        "op": Op.CAST,
        "operands": (1, 0),
        "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
        "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
    },
    "rescale": {
        "op": Op.RESCALE,
        "operands": (1, 0),
        "rank": (1, 4),
        "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
        "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
        "error_if_validators": (
            TosaErrorValidator.evInputZeroPointNotZero,
            TosaErrorValidator.evOutputZeroPointNotZero,
            TosaErrorValidator.evScaleTrue,
            TosaErrorValidator.evScaleNotTrue,
            TosaErrorValidator.evWrongInputType,
            TosaErrorValidator.evWrongOutputType,
            TosaErrorValidator.evWrongRank,
            TosaErrorValidator.evWrongInputList,
            TosaErrorValidator.evWrongOutputList,
        ),
    },
    # Custom
    # Not implemented.
    # Control flow operators
    # Two variants of cond_if: one that generates one of two constant tensors
    # (no inputs to the basic blocks, one output) and another that either adds
    # or subtracts two tensors (two inputs to the basic blocks, one output)
    "cond_if_const": {
        "op": Op.COND_IF,
        "operands": (0, 2),
        "build_fcn": (
            build_cond_if_const,
            TosaTensorGen.tgBasic,
            TosaArgGen.agCondIf,
        ),
        "types": [DType.BOOL],
    },
    "cond_if_binary": {
        "op": Op.COND_IF,
        "operands": (2, 0),
        "build_fcn": (
            build_cond_if_binary,
            TosaTensorGen.tgBasic,
            TosaArgGen.agCondIf,
        ),
        "types": TYPE_INT_FP,
    },
    # while_loop
    "while_loop": {
        "op": Op.WHILE_LOOP,
        "operands": (0, 1),
        "build_fcn": (
            build_while_loop,
            TosaTensorGen.tgBasic,
            TosaArgGen.agWhileLoop,
        ),
        "types": [DType.INT32],
    },
}
4289
Kevin Cheng550ccc52021-03-03 11:21:43 -08004290
Eric Kunzee5e26762020-10-13 16:11:07 -07004291class OutputShaper:
4292 # Methods in this class compute the expected output shape and datatype
4293 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; no per-instance state is set up.

    The shape helpers visible alongside this constructor are all
    staticmethods, so an instance carries nothing of its own.
    """
4296
4297 # These methods return arguments that can be used for
4298 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary op with broadcasting.

    For valid tests (``error_name is None``), a size-1 dimension of ``a``
    takes the corresponding size from ``b``. For any ERROR_IF test, ``a``'s
    shape is used unchanged.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator. It is only used (``rng.choice``) to pick a
            deliberately wrong output dtype for ``ErrorIf.WrongOutputType``.
        a, b: input tensors exposing ``.shape`` and ``.dtype``.
        error_name: optional ErrorIf value selecting a negative test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # RankMismatch tests intentionally pass inputs of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, a_dim in enumerate(a.shape):
        # b.shape is only indexed for valid tests, so a rank-mismatched b
        # is never read out of range.
        if a_dim == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a_dim)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        # Keep the declared order (not set-iteration order) so the seeded
        # rng.choice draw is reproducible by construction.
        wrong_dtypes = [dtype for dtype in all_dtypes if dtype != a.dtype]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004320
4321 @staticmethod
4322 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004323 assert len(a.shape) == len(b.shape)
4324 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004325
4326 shape = []
4327 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004328 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004329 shape.append(a.shape[i])
4330
Kevin Cheng550ccc52021-03-03 11:21:43 -08004331 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004332
4333 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004334 def unaryOp(ser, rng, a, error_name=None):
4335 if error_name == ErrorIf.WrongOutputType:
4336 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4337 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4338 outputDType = rng.choice(wrong_dtypes)
4339 else:
4340 outputDType = a.dtype
4341
4342 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004343
4344 @staticmethod
4345 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004346 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
4347 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004348
4349 shape = []
4350 for i in range(len(a.shape)):
4351 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4352
Kevin Cheng550ccc52021-03-03 11:21:43 -08004353 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004354
4355 @staticmethod
4356 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004357 assert len(a.shape) == len(b.shape)
4358 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004359
4360 # Do broadcast
4361 shape = []
4362 for i in range(len(a.shape)):
4363 if a.shape[i] == 1:
4364 shape.append(b.shape[i])
4365 else:
4366 shape.append(a.shape[i])
4367
4368 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08004369 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004370
4371 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004372 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004373 shape = a.shape.copy()
Matthew Haddond6ce7252021-09-29 15:35:44 +01004374 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
4375 shape[axis] = 1
4376 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4377 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004378
Matthew Haddond6ce7252021-09-29 15:35:44 +01004379 if error_name == ErrorIf.WrongOutputType:
4380 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4381 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4382 outputDType = rng.choice(wrong_dtypes)
4383 else:
4384 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004385
Matthew Haddond6ce7252021-09-29 15:35:44 +01004386 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004387
4388 @staticmethod
4389 def argmaxOp(ser, a, axis):
4390 shape = a.shape.copy()
4391 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004392 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07004393
4394 @staticmethod
4395 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
4396
4397 # IFM: NHWC
4398 # Filter: OHWI
4399 # OFM: NHWC
4400
4401 if len(padding) == 2:
4402 # Expand padding to 4 parameters in the case of transpose_conv2d
4403 # From H,W to T,B,L,R
4404 padding = [padding[0], padding[0], padding[1], padding[1]]
4405
Kevin Cheng550ccc52021-03-03 11:21:43 -08004406 h = (
4407 ifm.shape[1]
4408 - filter.shape[1]
4409 - (filter.shape[1] - 1) * (dilations[0] - 1)
4410 + padding[0]
4411 + padding[1]
4412 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004413
Kevin Cheng550ccc52021-03-03 11:21:43 -08004414 w = (
4415 ifm.shape[2]
4416 - filter.shape[2]
4417 - (filter.shape[2] - 1) * (dilations[1] - 1)
4418 + padding[2]
4419 + padding[3]
4420 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004421
Eric Kunzee5e26762020-10-13 16:11:07 -07004422 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4423
Kevin Cheng3a478572021-01-22 17:21:02 -08004424 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004425 out_dtype = DType.INT32
4426 elif ifm.dtype == DType.INT16:
4427 out_dtype = DType.INT48
4428 elif ifm.dtype == DType.FLOAT:
4429 out_dtype = DType.FLOAT
4430 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004431 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004432
Kevin Cheng550ccc52021-03-03 11:21:43 -08004433 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004434
4435 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07004436 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
4437
4438 # IFM: NDHWC
4439 # Filter: ODHWI
4440 # OFM: NDHWC
4441
4442 d = (
4443 ifm.shape[1]
4444 - filter.shape[1]
4445 - (filter.shape[1] - 1) * (dilations[0] - 1)
4446 + padding[0]
4447 + padding[1]
4448 ) // strides[0] + 1
4449
4450 h = (
4451 ifm.shape[2]
4452 - filter.shape[2]
4453 - (filter.shape[2] - 1) * (dilations[1] - 1)
4454 + padding[2]
4455 + padding[3]
4456 ) // strides[1] + 1
4457
4458 w = (
4459 ifm.shape[3]
4460 - filter.shape[3]
4461 - (filter.shape[3] - 1) * (dilations[2] - 1)
4462 + padding[4]
4463 + padding[5]
4464 ) // strides[2] + 1
4465
4466 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4467
4468 if ifm.dtype == DType.INT8:
4469 out_dtype = DType.INT32
4470 elif ifm.dtype == DType.INT16:
4471 out_dtype = DType.INT48
4472 elif ifm.dtype == DType.FLOAT:
4473 out_dtype = DType.FLOAT
4474 else:
4475 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
4476
4477 return ser.addOutput(ofm_shape, out_dtype)
4478
4479 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07004480 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
4481 # IFM: NHWC
4482 # Filter: HWCM
4483 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08004484 h = (
4485 ifm.shape[1]
4486 - filter.shape[0]
4487 - (filter.shape[0] - 1) * (dilations[0] - 1)
4488 + padding[0]
4489 + padding[1]
4490 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004491
Kevin Cheng550ccc52021-03-03 11:21:43 -08004492 w = (
4493 ifm.shape[2]
4494 - filter.shape[1]
4495 - (filter.shape[1] - 1) * (dilations[1] - 1)
4496 + padding[2]
4497 + padding[3]
4498 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004499
Eric Kunzee5e26762020-10-13 16:11:07 -07004500 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4501
Kevin Cheng3a478572021-01-22 17:21:02 -08004502 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004503 out_dtype = DType.INT32
4504 elif ifm.dtype == DType.INT16:
4505 out_dtype = DType.INT48
4506 elif ifm.dtype == DType.FLOAT:
4507 out_dtype = DType.FLOAT
4508 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004509 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004510
Kevin Cheng550ccc52021-03-03 11:21:43 -08004511 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004512
4513 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004514 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004515 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004516 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
4517 # If an incorrect stride is used set dimensions to 0, test is invalid anyway.
4518 h = 1
4519 w = 1
4520 else:
4521 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
4522 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
4523
4524 if error_name == ErrorIf.PoolingOutputShapeMismatch:
4525 choices = [1, 2, 3, 4, 5]
4526 h = h + rng.choice(choices)
4527 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07004528
Eric Kunzee5e26762020-10-13 16:11:07 -07004529 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004530
4531 if error_name == ErrorIf.WrongOutputType:
4532 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4533 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4534 outputDType = rng.choice(wrong_dtypes)
4535 else:
4536 outputDType = ifm.dtype
4537
4538 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004539
4540 @staticmethod
4541 def fullyConnectedOp(ser, input, filter):
4542 # input: N, IC
4543 # filter: OC, IC
4544 # output: N, OC
4545
4546 output_shape = [input.shape[0], filter.shape[0]]
4547
Kevin Cheng3a478572021-01-22 17:21:02 -08004548 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004549 out_dtype = DType.INT32
4550 elif input.dtype == DType.INT16:
4551 out_dtype = DType.INT48
4552 elif input.dtype == DType.FLOAT:
4553 out_dtype = DType.FLOAT
4554 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004555 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004556
Kevin Cheng550ccc52021-03-03 11:21:43 -08004557 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004558
4559 @staticmethod
4560 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004561 # a: N, H, C
4562 # b: N, C, W
4563 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004564
Kevin Cheng2d60f002021-06-09 14:18:32 -07004565 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004566
Kevin Cheng3a478572021-01-22 17:21:02 -08004567 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004568 out_dtype = DType.INT32
4569 elif a.dtype == DType.INT16:
4570 out_dtype = DType.INT48
4571 elif a.dtype == DType.FLOAT:
4572 out_dtype = DType.FLOAT
4573 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004574 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004575
Kevin Cheng550ccc52021-03-03 11:21:43 -08004576 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004577
4578 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01004579 def concatOp(ser, axis, *a):
4580 input1 = a[0]
4581 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004582
Matthew Haddon818ab902021-07-27 09:12:49 +01004583 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07004584
Matthew Haddon818ab902021-07-27 09:12:49 +01004585 output_shape[axis] = input1.shape[axis]
4586
4587 for tensor in remaining_inputs:
4588 output_shape[axis] += tensor.shape[axis]
4589
4590 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004591
4592 @staticmethod
4593 def padOp(ser, a, padding):
4594
4595 output_shape = a.shape.copy()
4596
4597 for i in range(len(output_shape)):
4598 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4599
Kevin Cheng550ccc52021-03-03 11:21:43 -08004600 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004601
4602 @staticmethod
4603 def reshapeOp(ser, a, shape):
4604 output_shape = shape.copy()
4605
4606 totalElements = 1
4607 for i in a.shape:
4608 totalElements *= i
4609
4610 # If there are any -1 elements, figure out what that dimension must be
4611 totalOutputElements = 1
4612 for i in output_shape:
4613 if i != -1:
4614 totalOutputElements *= i
4615
4616 # And fill it in
4617 for i in range(len(output_shape)):
4618 if output_shape[i] == -1:
4619 output_shape[i] = totalElements // totalOutputElements
4620
Kevin Cheng550ccc52021-03-03 11:21:43 -08004621 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004622
4623 @staticmethod
4624 def sliceOp(ser, a, begin, size):
4625
4626 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004627 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004628
4629 @staticmethod
4630 def tileOp(ser, a, multiples):
4631
4632 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004633 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004634
4635 for i in range(len(output_shape)):
4636 output_shape[i] = a.shape[i] * multiples[i]
4637
Kevin Cheng550ccc52021-03-03 11:21:43 -08004638 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004639
4640 @staticmethod
4641 def transposeOp(ser, a, perms):
4642 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004643 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004644
4645 for i in range(len(output_shape)):
4646 output_shape[i] = a.shape[perms[i]]
4647
Kevin Cheng550ccc52021-03-03 11:21:43 -08004648 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004649
4650 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08004651 def gatherOp(ser, values, indices):
4652 assert len(values.shape) == 3
4653 assert len(indices.shape) == 2
4654 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004655
Kevin Cheng77d0f762020-11-24 10:26:32 -08004656 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4657
Kevin Cheng550ccc52021-03-03 11:21:43 -08004658 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004659
4660 @staticmethod
4661 def scatterOp(ser, values_in, indices, input):
4662 assert len(values_in.shape) == 3
4663 assert len(indices.shape) == 2
4664 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004665 assert values_in.shape[0] == indices.shape[0] # N
4666 assert input.shape[1] == indices.shape[1] # W
4667 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004668
4669 output_shape = values_in.shape
4670
Kevin Cheng550ccc52021-03-03 11:21:43 -08004671 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004672
4673 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004674 def tableOp(ser, input, table_dtype):
4675 # Same shape as the input, but dtype dependent on table dtype
4676 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
4677 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
4678 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004679
4680 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004681 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004682 serializer,
4683 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004684 input,
4685 mode,
4686 stride,
4687 offset,
4688 shift,
4689 stride_fp,
4690 offset_fp,
4691 output_dims,
4692 input_dtype,
4693 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004694 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08004695 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01004696 if error_name == ErrorIf.WrongRank:
4697 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
4698 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004699 if error_name == ErrorIf.BatchMismatch:
4700 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
4701 elif error_name == ErrorIf.ChannelMismatch:
4702 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
4703 else:
4704 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004705
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004706 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004707
4708 @staticmethod
4709 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004710 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004711
4712 @staticmethod
4713 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08004714 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004715 out_dtype = DType.INT32
4716 elif ifm.dtype == DType.INT16:
4717 out_dtype = DType.INT48
4718 elif ifm.dtype == DType.FLOAT:
4719 out_dtype = DType.FLOAT
4720 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004721 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004722
Kevin Cheng550ccc52021-03-03 11:21:43 -08004723 return ser.addOutput(output_shape, out_dtype)