blob: 1ec4a47a181dae544d4feec00b6fa913ecc5a604 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
# Include the ../thirdparty/serialization_lib/python directory in the module
# search path (sys.path).  The directory is resolved relative to the real
# location of this file, so it does not depend on the current working directory.
# NOTE(review): presumably required by the tosa_serializer / tosa imports that
# follow this statement - confirm.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# Convenience variables to the flatc-generated types that should be enums, but aren't.
# Each name is bound to an instance of the generated class, so its constants can
# be read as plain attributes (e.g. DType.INT8, DType.UINT8).
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """Random generators for operator quantization info (zero points).

    Select one of these helpers with the 'qgen': key in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a zero point for a tensor of type dtype.

        INT8 uses testGen.randInt(-128, 128) and UINT8 uses
        testGen.randInt(0, 256).  Every other dtype gets 0, except when
        error_name is one of the *ZeroPointNotZero ERROR_IF cases: then a
        value that is guaranteed to be non-zero is returned.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)

        zero_point_errors = (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        )
        if error_name not in zero_point_errors:
            return 0

        # The ERROR_IF test needs a zero point that is definitely not zero
        candidate = testGen.randInt(-128, 128)
        return 1 if candidate == 0 else candidate

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build unary quantization info holding an input and an output zero point.

        error_name is only forwarded to the zero point it targets
        (InputZeroPointNotZero -> input, OutputZeroPointNotZero -> output).
        """
        qinfo = ts.TosaSerializerQuantInfo()

        if error_name == ErrorIf.InputZeroPointNotZero:
            per_operand_error = (error_name, None)
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            per_operand_error = (None, error_name)
        else:
            per_operand_error = (None, None)

        # Input first, then output: keeps the random number sequence unchanged
        input_zp = TosaQuantGen.getQinfo(testGen, dtype, per_operand_error[0])
        output_zp = TosaQuantGen.getQinfo(testGen, dtype, per_operand_error[1])

        qinfo.UnaryQuantInfo(input_zp, output_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build convolution quantization info (input and weights zero points).

        dtype_or_dtypeList is either a list of [input, weights, accumulator]
        dtypes or a single dtype used for all three.
        """
        qinfo = ts.TosaSerializerQuantInfo()

        dtypeList = (
            dtype_or_dtypeList
            if isinstance(dtype_or_dtypeList, list)
            else [dtype_or_dtypeList] * 3
        )

        if error_name == ErrorIf.InputZeroPointNotZero:
            per_operand_error = (error_name, None)
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            per_operand_error = (None, error_name)
        else:
            per_operand_error = (None, None)

        # Input first, then weights: keeps the random number sequence unchanged
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], per_operand_error[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], per_operand_error[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build matmul quantization info with a zero point for each operand.

        For InputZeroPointNotZero the error is applied to both zero points.
        """
        qinfo = ts.TosaSerializerQuantInfo()

        zp_error = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        a_zp = TosaQuantGen.getQinfo(testGen, dtype, zp_error)
        b_zp = TosaQuantGen.getQinfo(testGen, dtype, zp_error)

        qinfo.MatMulQuantInfo(a_zp, b_zp)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build pad quantization info with a single zero point (error_name is not used)."""
        qinfo = ts.TosaSerializerQuantInfo()
        input_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.PadQuantInfo(input_zp)
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects a
        31-bit multiplier scale, otherwise 15 bits are used.  The returned
        shift is always within 2..62.
        """
        scaleBits = 31 if scale32 else 15
        full_scale = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * full_scale)
        assert multiplier <= full_scale

        # Rounding can reach the full scale: halve and compensate in the exponent
        if multiplier == full_scale:
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift = 2
        elif shift == 1:
            multiplier //= 2
            shift = 2
        elif shift == 63:
            multiplier *= 2
            shift = 62

        assert multiplier <= full_scale
        assert 2 <= shift <= 62

        return multiplier, shift
172
173
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one shape per operand (placeholders + consts), all of the given rank.

        For ErrorIf.RankMismatch the shape is regenerated with a different rank
        after every operand except the one at index 1, so later operands no
        longer match the first one.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # Rank 1 can only grow
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return one copy of a random shape per operand.

        The rank must be 4 unless a WrongRank ERROR_IF test is being built.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict dimension size for large ranks when creating WrongRank tests
        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for a two-placeholder, rank 3 operator.

        input_shape shares dimensions 0 and 2 with values_in_shape; its middle
        dimension W is drawn separately.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        return [values_in_shape.copy(), input_shape.copy()]

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return one shape per operand, with one randomly chosen operand broadcastable.

        The chosen operand gets one random dimension set to 1.  For
        ErrorIf.RankMismatch broadcasting is switched off and operands other
        than the one at index 1 get a shape of a different rank instead.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a 2D convolution.

        The IFM is NHWC, the filter is OHWI with its height/width taken from
        op["filter"], and the bias has one entry per output channel.
        """
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a 3D convolution.

        The IFM is NDHWC, the filter is ODHWI with its depth/height/width taken
        from op["filter"], and the bias has one entry per output channel.
        """
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a 2D transpose convolution.

        Shapes are built exactly as for tgConv2D: NHWC input, OHWI filter and
        a bias with one entry per output channel.
        """
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm_shape, filter_shape, bias_shape] for a depthwise 2D convolution.

        The filter is HWCM and the bias has C * M entries, where C is the input
        channel count and M a bounded random channel multiplier.
        """
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input_shape, filter_shape, bias_shape] for a fully connected layer.

        The filter is [output channels, input_shape[1]] and the bias has one
        entry per output channel.  The rank must be 2 unless a WrongRank
        ERROR_IF test is being built.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # Keep the returned shape (as tgBasic/tgNHWC do) rather than relying on
        # the helper modifying its argument in place.
        input_shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] for a batched matrix multiply.

        b_shape is [a_shape[0], a_shape[2], b_oc], where b_oc is random, or 1
        when a_shape has a dimension above 1000.  The rank must be 3 unless a
        WrongRank ERROR_IF test is being built.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict dimension size for large ranks when creating WrongRank tests.
        # Keep the returned shape (as tgBasic/tgNHWC do) rather than relying on
        # the helper modifying its argument in place.
        a_shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return a list of identical shapes for a concat with extra random inputs.

        testGen.randInt(0, 4) additional tensors are added on top of the
        operator's own placeholder and const operands.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Shrink the inputs after the first one by splitting the shape along axis.

        shapeList is returned unchanged when it has exactly two entries or when
        the first shape is shorter along axis than the number of shapes.
        Otherwise a new list is returned: the first shape is kept and each
        following shape takes half of what is left along axis, the last one
        taking the remainder.  shapeList itself is never modified.
        """
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                # Last split: the remainder becomes the final input
                new_shapeList.append(shape.copy())

        return new_shapeList
487
488
Eric Kunzee5e26762020-10-13 16:11:07 -0700489class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800490 """Argument generators create exhaustive or random lists of attributes for operators that take
491 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
492 tuples where the descriptive_name is appended to the test name and the arglist is expanded
493 as arguments to the operator build function."""
494
def __init__(self):
    # Nothing to set up: no instance attributes are assigned here.
    # NOTE(review): the generators visible in this class are all static
    # methods, so instances appear to carry no state - confirm.
    pass
497
498 @staticmethod
def agNone(testGen, opName, shapeList, dtype, error_name=None):
    """Argument generator for operators without any non-tensor arguments.

    Always yields a single entry: an empty test-name suffix paired with an
    empty argument list.
    """
    no_arguments = ("", [])
    return [no_arguments]
Eric Kunzee5e26762020-10-13 16:11:07 -0700503
504 @staticmethod
def agAxis(testGen, opName, shapeList, dtype, error_name=None):
    """Build the axis argument for operators that take a single axis.

    Normally one entry is produced for every axis of the first shape.  For the
    AxisSmallerZero / AxisLargerRank ERROR_IF cases a single random
    out-of-range axis is produced instead.
    """
    rank = len(shapeList[0])

    if error_name == ErrorIf.AxisSmallerZero:
        candidates = [testGen.rng.integers(-5, 0)]
    elif error_name == ErrorIf.AxisLargerRank:
        candidates = [testGen.rng.integers(rank + 1, rank + 10)]
    else:
        candidates = range(0, rank)

    return [("axis{}".format(axis), [axis]) for axis in candidates]
521
522 @staticmethod
def agConv(testGen, opName, shapeList, dtype, error_name=None):
    """Generate (name, [stride, padding, dilation]) argument sets for convolutions.

    The kernel size is parsed from the operator name (e.g. "conv2d_3x3" =>
    [3, 3]); operators whose name starts with "conv3d" are rank 5, all others
    rank 4.  Every combination of stride, padding and dilation up to the
    testGen.args.max_conv_* limits is considered (plus some oversize values
    for small inputs), sampled sparsely when there are many of them, and kept
    only if the padded input exceeds both the kernel size and the dilation on
    every spatial axis.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]
    # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
    k = [int(x) for x in opName.split("_")[-1].split("x")]

    # Check the rank
    rank = 5 if opName.startswith("conv3d") else 4
    assert len(ifm_shape) == rank
    assert len(filter_shape) == rank

    # kernel rank omits batch and channels
    k_rank = rank - 2

    # Generate comprehensive argument lists
    p_vals = range(0, testGen.args.max_conv_padding + 1)
    paddings = set(itertools.product(p_vals, repeat=k_rank * 2))
    s_vals = range(1, testGen.args.max_conv_stride + 1)
    strides = set(itertools.product(s_vals, repeat=k_rank))
    d_vals = range(1, testGen.args.max_conv_dilation + 1)
    dilations = set(itertools.product(d_vals, repeat=k_rank))

    # add some oversize argument values
    if max(ifm_shape) < 64:
        bigPadding = 9
        paddings.update(itertools.product([0, bigPadding], repeat=k_rank * 2))
        bigStride = 8
        strides.update(itertools.product([1, bigStride], repeat=k_rank))
        bigDilation = 7
        dilations.update(itertools.product([1, bigDilation], repeat=k_rank))

    # There are too many parameter combinations, so generate them sparsely
    # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
    sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
    if sparsity < 13:
        sparsity = 1
    while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
        sparsity += 1

    # Sort once up front: the iteration order is the same for every outer loop pass
    sorted_strides = sorted(strides)
    sorted_paddings = sorted(paddings)
    sorted_dilations = sorted(dilations)
    axes = range(k_rank)

    n = 0
    for s in sorted_strides:
        for p in sorted_paddings:
            # Padded size of every spatial axis (ifm axis 0 is the batch)
            padded = [ifm_shape[i + 1] + p[2 * i] + p[2 * i + 1] for i in axes]
            # the padded shape must exceed the kernel size
            exceeds_kernel = all(padded[i] > k[i] for i in axes)
            for d in sorted_dilations:
                if (
                    n % sparsity == 0
                    and exceeds_kernel
                    # the padded shape must exceed the dilation
                    and all(padded[i] > d[i] for i in axes)
                ):
                    arg_list.append(
                        (
                            "st{}_pad{}_dilat{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "".join([str(x) for x in d]),
                            ),
                            [s, p, d],
                        )
                    )
                # n counts every combination, including the rejected ones
                n += 1

    return arg_list
591
592 @staticmethod
def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
    """Generate (name, [stride, padding, dilation, output_shape]) argument sets.

    Every combination of 2D stride, padding and dilation up to the
    testGen.args.max_conv_* limits is considered (plus some oversize values
    for small inputs) and sampled sparsely when there are many of them.  The
    output shape is computed from the IFM shape (shapeList[0]) and the filter
    shape (shapeList[1]), both of which must be rank 4.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]

    # Must be rank 4
    assert len(ifm_shape) == 4
    assert len(filter_shape) == 4

    # Generate comprehensive argument lists
    p_vals = range(0, testGen.args.max_conv_padding + 1)
    paddings = set(itertools.product(p_vals, repeat=2))
    s_vals = range(1, testGen.args.max_conv_stride + 1)
    strides = set(itertools.product(s_vals, repeat=2))
    d_vals = range(1, testGen.args.max_conv_dilation + 1)
    dilations = set(itertools.product(d_vals, repeat=2))

    # add some oversize argument values
    if max(ifm_shape) < 64:
        bigPadding = 9
        paddings.update(itertools.product([0, bigPadding], repeat=2))
        bigStride = 8
        strides.update(itertools.product([1, bigStride], repeat=2))
        bigDilation = 7
        dilations.update(itertools.product([1, bigDilation], repeat=2))

    # There are too many parameter combinations, so generate them sparsely
    # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
    sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
    if sparsity < 13:
        sparsity = 1
    while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
        sparsity += 1

    # Sort once up front: the iteration order is the same for every outer loop pass
    sorted_strides = sorted(strides)
    sorted_paddings = sorted(paddings)
    sorted_dilations = sorted(dilations)

    n = 0
    for s in sorted_strides:
        for p in sorted_paddings:
            for d in sorted_dilations:
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1
                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1
                    # Named out_shape so the imported os module is not shadowed
                    out_shape = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "st{}_pad{}_dilat{}_os{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "".join([str(x) for x in d]),
                                "x".join([str(x) for x in out_shape]),
                            ),
                            [s, p, d, out_shape],
                        )
                    )
                # n counts every combination, including the skipped ones
                n += 1

    return arg_list
660
661 @staticmethod
def agPad(testGen, opName, shapeList, dtype, error_name=None):
    """Generate every combination of before/after padding for each dimension.

    Each entry is (name, [paddings]) where paddings is a numpy array holding
    one (before, after) pair per dimension of the first shape and the name is
    "pad" followed by all the pad values in order.
    """
    rank = len(shapeList[0])

    # Exhaustively test combinations of padding on each side of each dimension
    # - the range of padding values is defined by pad_min and pad_max
    # - for padding >9, the name format needs to be more distinctive
    pad_min, pad_max = 0, 1
    pad_values = range(pad_min, pad_max + 1)
    axis_pad_values = list(itertools.product(pad_values, repeat=2))

    arg_list = []
    for paddings in itertools.product(axis_pad_values, repeat=rank):
        suffix = "".join(f"{before}{after}" for before, after in paddings)
        arg_list.append(("pad" + suffix, [np.array(paddings)]))

    return arg_list
682
683 @staticmethod
def agPooling(testGen, opName, shapeList, dtype, error_name=None):
    """Generate (name, [stride, padding, kernel]) argument sets for pooling.

    Every combination of stride, padding and kernel up to the
    testGen.args.max_pooling_* limits is considered (plus some oversize
    values) and sampled sparsely.  For the stride/kernel/pad ERROR_IF cases
    each combination is passed through TosaErrorIfArgGen.eiPoolingErrorIf and
    its result is used when none of the returned values is None; otherwise
    only combinations that pass the validity checks below are kept.  The first
    shape must be rank 4 unless a WrongRank ERROR_IF test is being built.
    """
    arg_list = []

    shape = shapeList[0]
    if error_name != ErrorIf.WrongRank:
        assert len(shape) == 4

    # Generate comprehensive argument lists
    p_vals = range(0, testGen.args.max_pooling_padding + 1)
    paddings = set(itertools.product(p_vals, repeat=4))
    s_vals = range(1, testGen.args.max_pooling_stride + 1)
    strides = set(itertools.product(s_vals, repeat=2))
    k_vals = range(2, testGen.args.max_pooling_kernel + 2)
    kernels = set(itertools.product(k_vals, repeat=2))

    # add some oversize argument values
    bigStride = 7
    strides.update(itertools.product([1, bigStride], repeat=2))
    bigKernel = 6
    kernels.update(itertools.product([2, bigKernel], repeat=2))
    if max(shape) < 64:
        # padding must be less than the kernel size
        bigPadding = bigKernel - 1
        paddings.update(itertools.product([0, bigPadding], repeat=4))

    # There are too many parameter combinations, so generate them sparsely
    sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1

    # Sort once up front: the iteration order is the same for every outer loop pass
    sorted_strides = sorted(strides)
    sorted_paddings = sorted(paddings)
    sorted_kernels = sorted(kernels)

    # ERROR_IF cases whose arguments come from eiPoolingErrorIf
    is_attribute_error = error_name in [
        ErrorIf.StrideSmallerOne,
        ErrorIf.KernelSmallerOne,
        ErrorIf.PadSmallerZero,
        ErrorIf.PadLargerEqualKernel,
    ]

    n = 0
    for s in sorted_strides:
        for p in sorted_paddings:
            for k in sorted_kernels:
                # Calculate output height to test for error_if conditions
                oh = (shape[1] + p[0] + p[1] + s[0] - k[0]) // s[0]
                ow = (shape[2] + p[2] + p[3] + s[1] - k[1]) // s[1]
                # Input extents derived back from the output size
                # NOTE(review): presumably the extent the pooling windows
                # cover - confirm against the specification
                y_extent = (oh * s[0]) - p[0] - p[1] - s[0] + k[0]
                x_extent = (ow * s[1]) - p[2] - p[3] - s[1] + k[1]

                if is_attribute_error:
                    sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                    if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in sNew]),
                                    "".join([str(x) for x in kNew]),
                                    "".join([str(x) for x in pNew]),
                                ),
                                [sNew, pNew, kNew],
                            )
                        )
                elif (
                    n % sparsity == 0
                    # padding must not exceed the kernel size
                    and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                    # the padded shape must exceed the kernel size
                    and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    and y_extent < shape[1] and x_extent < shape[2]
                ):
                    arg_list.append(
                        (
                            "st{}_kern{}_pad{}".format(
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in k]),
                                "".join([str(x) for x in p]),
                            ),
                            [s, p, k],
                        )
                    )
                # n counts every combination, including the rejected ones
                n += 1

    return arg_list
754
755 @staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate the output data types to cast ``inDtype`` to.

    Returns one ``("out<TypeName>", [dtype])`` entry per output type.
    Raises ``Exception`` when ``inDtype`` is not one of INT8, INT16, INT32,
    BOOL or FLOAT.  The other parameters are not used.
    """
    # Output types to enumerate for each supported input type
    cast_table = (
        (DType.INT8, [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]),
        (DType.INT16, [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]),
        (DType.INT32, [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]),
        (DType.BOOL, [DType.INT8, DType.INT16, DType.INT32]),
        (DType.FLOAT, [DType.INT8, DType.INT16, DType.INT32]),
    )

    for candidate, outputs in cast_table:
        if inDtype == candidate:
            dtypeList = outputs
            break
    else:
        raise Exception("Unexpected input dtype: {}".format(inDtype))

    return [("out{}".format(DTypeNames[out]), [out]) for out in dtypeList]
777
778 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100779 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700780 arg_list = []
781
782 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100783 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100784 if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
785 continue
786 if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100787 # The only output dtype for UINT8 is INT8, skip all other combinations
788 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100789 if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100790 # The only input dtype for UINT8 is INT8, skip all other combinations
791 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100792 if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
793 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100794
Kevin Cheng550ccc52021-03-03 11:21:43 -0800795 for scale32 in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100796 if error_name == ErrorIf.ScaleTrue and scale32 == False:
797 continue
798 elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
799 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800800 for double_round in [False, True]:
Matthew Haddonc2025212021-10-08 21:21:05 +0100801 if error_name == ErrorIf.ScaleNotTrue and double_round == False:
802 continue
Kevin Cheng550ccc52021-03-03 11:21:43 -0800803 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700804
Matthew Haddonc2025212021-10-08 21:21:05 +0100805 if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 # Illegal condition. Must be scale32=False
807 continue
Matthew Haddonc2025212021-10-08 21:21:05 +0100808 if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100809 # Illegal condition. ERROR_IF(!scale32 && double_round)
810 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700811
Kevin Cheng550ccc52021-03-03 11:21:43 -0800812 arg_list.append(
813 (
814 "out{}_sc{}_dr{}_pc{}".format(
815 DTypeNames[dtype],
816 int(scale32),
817 int(double_round),
818 int(per_channel),
819 ),
820 [dtype, scale32, double_round, per_channel],
821 )
822 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700823
824 return arg_list
825
Kevin Chengaee1fac2020-11-11 13:54:06 -0800826 @staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Build the shift arguments for a multiply test.

    For ``DType.INT32`` one entry is generated per
    ``testGen.args.num_rand_permutations``, each with a shift drawn from
    ``testGen.randInt(0, 32)``.  Any other dtype gets a single entry with a
    shift of 0.

    Returns a list of ``("perm<p>_shift<shift>", [shift])`` tuples.
    ``opName``, ``shapeList`` and ``error_name`` are not used.
    """
    # Compare by value: DType is a flatc-generated constants holder rather
    # than a real Enum, so an identity test (`is`) only matches when both
    # sides happen to be the very same object.
    if dtype == DType.INT32:
        arg_list = []
        for p in range(testGen.args.num_rand_permutations):
            shift = testGen.randInt(0, 32)
            arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        return arg_list

    return [("perm0_shift0", [0])]
840
841 @staticmethod
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
    """Return the two rounding variants: round=True, then round=False.

    None of the parameters influence the result.
    """
    return [
        ("roundTrue", [True]),
        ("roundFalse", [False]),
    ]
849
Eric Kunzee5e26762020-10-13 16:11:07 -0700850 # Helper function for reshape. Gets some factors of a larger number.
851 @staticmethod
def getFactors(val, start=1):
    """Return the divisors of ``val`` in ``[start, int(sqrt(val))]``, ascending.

    Only divisors up to the integer square root are reported, so the larger
    co-factors are not included.
    """
    upper = int(np.sqrt(val)) + 1
    return [candidate for candidate in range(start, upper) if (val % candidate) == 0]
860
861 @staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random target shapes with the same element count as the input.

    For each of ``testGen.args.num_rand_permutations`` a random rank is drawn
    with ``testGen.randInt(1, 7)`` and a shape of that rank is built from
    factors of the total element count; one dimension is occasionally
    replaced by -1.  A shape already present in the result is re-drawn, up
    to 100 attempts.

    Returns a list of ``("perm<p>_rank<rank>", [newShape])`` tuples.
    """
    arg_list = []

    origShape = shapeList[0]

    totalElements = 1
    for dim in origShape:
        totalElements *= dim

    # This code is NOT fast.  Fortunately, the numbers are fairly small.
    factors = TosaArgGen.getFactors(totalElements)

    for p in range(testGen.args.num_rand_permutations):
        newRank = testGen.randInt(1, 7)
        if len(factors) < newRank:
            continue

        duplicate = True
        # attempts bounds the retry loop if unique shapes are hard to find
        attempts = 0
        while duplicate:
            # Generate newShape ensuring it isn't a duplicate
            newShape = []
            remainingElements = totalElements
            candidates = testGen.rng.permutation(factors)
            for _ in range(1, newRank):
                # pick rank-1 factors
                picked = candidates[0]
                newShape.append(picked)
                remainingElements = remainingElements // picked
                candidates = testGen.rng.permutation(
                    TosaArgGen.getFactors(remainingElements)
                )
            newShape.append(remainingElements)

            # Toss in a -1 sometimes
            minusOne = testGen.randInt(0, newRank * 4)
            if minusOne < newRank:
                newShape[minusOne] = -1

            # Check for duplicates
            duplicate = False
            for _name, existing in arg_list:
                if existing[0] == newShape:
                    duplicate = True
                    break

            attempts += 1
            if attempts >= 100:
                break

            # NOTE(review): assumed to sit inside the retry loop, after the
            # attempt limit check; the source view lost its indentation.
            if not duplicate:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

    return arg_list
916
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 @staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
    """Pick random axis permutations for the first input shape.

    All permutations of the input's axes are shuffled with
    ``testGen.rng.permutation`` and the first
    ``min(number of permutations, testGen.args.num_rand_permutations)`` are
    returned as ``("perm<i>", [axes_list])`` tuples.
    """
    rank = len(shapeList[0])

    # Get all permutations
    every_perm = list(itertools.permutations(range(rank)))

    # Limit to possible permutations from shape dimension or argument setting
    count = min(len(every_perm), testGen.args.num_rand_permutations)

    # Shuffle so that the selection is drawn from all permutations
    shuffled = testGen.rng.permutation(every_perm)

    # Create list of required amount of permutations
    return [
        ("perm{}".format(idx), [shuffled[idx].tolist()])
        for idx in range(count)
    ]
938
939 @staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random ``begin``/``size`` slice arguments for the first input.

    For every dimension larger than 1 a begin and a size are drawn with
    ``testGen.randInt``; dimensions of size 1 or less use begin 0 and size 1.
    A permutation in which any drawn size is 0 is dropped.

    Returns a list of ``("perm<p>", [begin, size])`` tuples.
    """
    arg_list = []

    ifm_shape = shapeList[0]

    for p in range(testGen.args.num_rand_permutations):
        begin, size = [], []
        valid = True

        for dim in ifm_shape:
            if dim > 1:
                start = testGen.randInt(0, dim)
                extent = testGen.randInt(0, dim - start)
                begin.append(start)
                size.append(extent)

                # Invalid slice size?  Keep drawing for the remaining
                # dimensions so the random stream stays unchanged.
                if extent == 0:
                    valid = False
            else:
                begin.append(0)
                size.append(1)

        if valid:
            arg_list.append(("perm{}".format(p), [begin, size]))

    return arg_list
967
968 @staticmethod
def agTile(testGen, opName, shapeList, dtype, error_name=None):
    """Generate small per-dimension tile multiples for the first input shape.

    A dimension larger than 1000 gets a multiple of 1; otherwise, if any
    dimension of the shape is larger than 1000 the multiple is 2; otherwise
    it is drawn with ``testGen.randInt(1, 4)``.

    Returns a list of ``("perm<p>", [multiples])`` tuples, one per
    ``testGen.args.num_rand_permutations``.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    has_large_dim = any(dim > 1000 for dim in ifm_shape)

    for p in range(testGen.args.num_rand_permutations):
        # Pick a few random, but small multiple values
        # because otherwise this has a tendency to generate
        # enormous tensors
        multiples = []
        for dim in ifm_shape:
            if dim > 1000:
                # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
                multiples.append(1)
            elif has_large_dim:
                multiples.append(2)
            else:
                multiples.append(testGen.randInt(1, 4))
        arg_list.append(("perm{}".format(p), [multiples]))

    return arg_list
992
993 @staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Generate resize arguments for NEAREST and BILINEAR modes.

    For each legal {mode, input dtype} pairing and each of
    ``testGen.args.num_rand_permutations``, random output dimensions are
    drawn and the matching stride/offset are derived from them: floating
    point values when the output type is FLOAT, fixed-point values with a
    shift (starting at 11 and reduced until everything is in range)
    otherwise.  When ``error_name`` is set, the values are passed through
    ``TosaErrorIfArgGen.eiResizeErrorIf`` before being recorded.

    Returns a list of ``(name, [mode, stride, offset, shift, stride_fp,
    offset_fp, output_dims, dtype, output_dtype])`` tuples.
    """
    arg_list = []

    ifm_shape = shapeList[0]
    for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
        is_nearest = mode == ResizeMode.NEAREST
        is_bilinear = mode == ResizeMode.BILINEAR

        # Exclude illegal {mode, type} configurations.  Pick legal output types
        if is_nearest and dtype == DType.INT8:
            outputDTypeList = [DType.INT8]
        elif is_nearest and dtype == DType.INT16:
            outputDTypeList = [DType.INT16]
        elif is_bilinear and dtype == DType.INT8:
            outputDTypeList = [DType.INT32]
        elif is_bilinear and dtype == DType.INT16:
            outputDTypeList = [DType.INT48]
        elif dtype == DType.FLOAT:
            outputDTypeList = [DType.FLOAT]
        elif error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            continue

        mode_tag = "N" if is_nearest else "B"

        for outputDType in outputDTypeList:
            float_output = outputDType == DType.FLOAT

            for _ in range(testGen.args.num_rand_permutations):
                # Randomly generate legal output dimensions and shift
                # and then compute the stride and offset based on them
                # A output_dim of 1 will cause offset to exceed allowed range
                # so minimum value 2 produced below
                output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
                    output_dims[0] += 1
                while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
                    output_dims[1] += 1

                # Index 0 is the height (y) axis, index 1 the width (x) axis
                in_center = [(ifm_shape[1] - 1) / 2.0, (ifm_shape[2] - 1) / 2.0]
                out_center = [(output_dims[0] - 1) / 2.0, (output_dims[1] - 1) / 2.0]
                fp_stride = [
                    float(ifm_shape[1]) / float(output_dims[0]),
                    float(ifm_shape[2]) / float(output_dims[1]),
                ]
                fp_offset = [
                    in_center[0] - fp_stride[0] * out_center[0],
                    in_center[1] - fp_stride[1] * out_center[1],
                ]

                if float_output:
                    shift = 0
                    stride = [0, 0]
                    offset = [0, 0]
                    stride_fp = fp_stride
                    offset_fp = fp_offset
                else:
                    # Start at the largest shift and lower it until the
                    # fixed-point stride and offset are all in range
                    shift = 11
                    while True:
                        unit = float(1 << shift)
                        stride = [int(round(v * unit)) for v in fp_stride]
                        offset = [int(round(v * unit)) for v in fp_offset]
                        limit = 16 << shift
                        out_of_range = (
                            any(v >= limit for v in stride)
                            or any(v >= limit for v in offset)
                            or any(v <= -limit for v in offset)
                        )
                        if not out_of_range:
                            break
                        shift = shift - 1

                    stride_fp = [0.0, 0.0]
                    offset_fp = [0.0, 0.0]

                if error_name is not None:
                    shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                        testGen,
                        error_name,
                        mode,
                        dtype,
                        shapeList,
                        outputDType,
                        shift,
                        stride,
                        stride_fp,
                        offset,
                        offset_fp
                    )
                else:
                    outputDTypeNew = outputDType

                if float_output:
                    name = "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                        mode_tag,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride_fp[0],
                        stride_fp[1],
                        offset_fp[0],
                        offset_fp[1],
                    )
                else:
                    name = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                        mode_tag,
                        shift,
                        output_dims[0],
                        output_dims[1],
                        testGen.typeStr(outputDTypeNew),
                        stride[0],
                        stride[1],
                        offset[0],
                        offset[1],
                    )

                arg_list.append(
                    (
                        name,
                        [
                            mode,
                            stride,
                            offset,
                            shift,
                            stride_fp,
                            offset_fp,
                            output_dims,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                )

    return arg_list
1163
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Return the two condition values for a CondIf test: False, then True.

    Only the condition values are generated here; converting them to
    tensors and building the then/else blocks is left to the build function.
    None of the parameters influence the result.
    """
    return [("cond{}".format(int(cond)), [cond]) for cond in (False, True)]
1174
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Return the iteration counts to test: 0 iterations, 1, more than 1.

    None of the parameters influence the result.
    """
    iteration_counts = (0, 1, 4)
    return [("iter{}".format(count), [count]) for count in iteration_counts]
1183
class TosaErrorIfArgGen:
    """Helpers that alter operator arguments so a chosen ERROR_IF is hit."""

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Adjust resize arguments so that ``error_name`` is triggered.

        Only the value(s) relevant to ``error_name`` are replaced; everything
        else is passed through.  Note that for StrideLargerDimension the
        supplied ``stride``/``stride_fp`` list is modified in place.

        Returns ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                ifm_shape = shapeList[0]
                # Randomly pick which of the two strides to oversize
                if testGen.rng.choice([False, True]):
                    stride_fp[0] = ifm_shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = ifm_shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    bound = (16 >> -shift) - 1
                else:
                    bound = (16 << shift) - 1
                # Same bound for stride and offset avoids other ERROR_IF checks
                stride = [bound, bound]
                offset = [bound, bound]
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                ifm_shape = shapeList[0]
                # Randomly pick which of the two strides to oversize
                if testGen.rng.choice([False, True]):
                    stride[0] = ifm_shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = ifm_shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                over_max = (16 << shift) + 1
                stride = [over_max, over_max]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                over_max = (16 << shift) + 1
                offset = [over_max, over_max]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                under_min = (-16 << shift) - 1
                offset = [under_min, under_min]

        if error_name == ErrorIf.WrongOutputType:
            # Each tuple lists every type except the one that is legal for
            # the given {mode, input dtype} pairing
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Replace one pooling argument with an invalid value for ``error_name``.

        Returns ``(stride, pad, kernel)`` with the offending argument swapped
        in, or ``(None, None, None)`` when no variant applies.
        """
        if error_name == ErrorIf.StrideSmallerOne:
            # padding must not exceed the kernel size
            pad_fits = (
                pad[0] < kernel[0]
                and pad[1] < kernel[0]
                and pad[2] < kernel[1]
                and pad[3] < kernel[1]
            )
            if pad_fits:
                wrongStride = (
                    testGen.rng.choice([0, -1, -2, -3]),
                    testGen.rng.choice([0, -1, -2, -3]),
                )
                return wrongStride, pad, kernel
            return None, None, None

        if error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel

        if error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel

        if error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
                testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
                testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
            )
            return stride, wrongPad, kernel

        return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when ``output_dtype`` is not accepted for ``input_dtype``.

        Input types other than INT8, INT16, INT32, INT48 and UINT8 always
        yield False.
        """
        if input_dtype == DType.INT8:
            accepted = [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]
        elif input_dtype in [DType.INT16, DType.INT32, DType.INT48]:
            accepted = [DType.INT8, DType.INT16, DType.INT32]
        elif input_dtype == DType.UINT8:
            accepted = [DType.INT8]
        else:
            return False
        return output_dtype not in accepted

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for ERROR_IF checks.

        For "WrongInputList" a dummy input is appended (in place) or the last
        input is dropped; for "WrongOutputList" a dummy output is appended
        (in place) or the outputs are emptied.  The choice is made with
        ``testGen.rng.choice``.  Returns ``(input_list, output_list)``.
        """
        if error_name == "WrongInputList":
            if testGen.rng.choice([True, False]):
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            if testGen.rng.choice([True, False]):
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimension(shape, error_name):
        """Set dimension 4 to 1 for WrongRank tests on shapes of rank above 4.

        This keeps the test sizes reasonably small.  ``shape`` is modified in
        place and also returned.
        """
        if error_name == ErrorIf.WrongRank and len(shape) > 4:
            shape[4] = 1

        return shape
1312
Matthew Haddone86fd342021-09-07 16:12:21 +01001313class TosaErrorValidator:
1314
Matthew Haddon848efb42021-09-09 12:30:53 +01001315 @staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
    """Run every ERROR_IF validator and record the expected failure.

    Each validator is called as ``val_fcn(True, **kwargs)`` and must return a
    dict with ``error_name``, ``error_result`` and ``error_reason`` keys.
    When the validator matching ``error_name`` reports a hit, return code 2
    and its reason are registered on ``serializer``.  Processing stops early,
    with a printed message, when a different validator is hit or when the
    matching validator is not hit.

    The function returns ``None`` on every path.
    """
    for val_fcn in validator_fcns:
        val_result = val_fcn(True, **kwargs)

        validator_name = val_result['error_name']
        error_result = val_result['error_result']
        error_reason = val_result['error_reason']

        is_requested = error_name == validator_name

        if error_result and is_requested:
            serializer.setExpectedReturnCode(2, error_reason)
        elif error_result:
            print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
            return None  # Return None to delete test if wrong ERROR_IF is hit
        elif is_requested:
            print(f"No ERROR_IF hit for {error_name}")
            return None
1336
1337 @staticmethod
def evWrongInputType(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongInputType ERROR_IF.

    Requires ``kwargs['op']``, a dict with ``'types'`` (and ``'op'`` when
    ``check`` is set).  The ``param_reqs`` entry lists every data type that
    is not among the operator's input types.  With ``check=True``,
    ``kwargs['input_dtype']`` is tested against the operator's types.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}

    # Find the unsupported input data types
    assert 'op' in kwargs
    op = kwargs['op']
    input_dtypes = op['types']

    # An entry given as a list contributes only its first element
    allowed_input_dtypes = set()
    for entry in input_dtypes:
        allowed_input_dtypes.add(entry[0] if isinstance(entry, list) else entry)
    wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)

    error_result = False
    if check:
        input_dtype = kwargs['input_dtype']
        # FULLY_CONNECTED is matched against the flattened type set,
        # every other operator against the raw type list
        if op['op'] == Op.FULLY_CONNECTED:
            legal_dtypes = allowed_input_dtypes
        else:
            legal_dtypes = input_dtypes
        error_result = input_dtype not in legal_dtypes

    return {
        "error_name": ErrorIf.WrongInputType,
        "error_result": error_result,
        "error_reason": "Input data type not supported for this operator",
        "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
    }
1369
1370 @staticmethod
def evWrongOutputType(check=False, **kwargs):
    """Describe, and optionally evaluate, the WrongOutputType ERROR_IF.

    With ``check=True`` the function reads ``input_dtype``, ``output_dtype``
    and ``op`` from ``kwargs`` (plus ``mode`` for RESIZE) and flags output
    types that do not fit the operator's input type.  Operators without a
    dedicated rule must have identical input and output types.

    Returns a dict with ``error_name``, ``error_result``, ``error_reason``
    and ``param_reqs``.
    """
    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        op = kwargs['op']

        if op['op'] == Op.RESIZE:
            mode = kwargs['mode']
            # Work out the single legal output type, if this
            # {mode, input type} pairing has one
            if mode == ResizeMode.NEAREST and input_dtype == DType.INT8:
                expected = DType.INT8
            elif mode == ResizeMode.NEAREST and input_dtype == DType.INT16:
                expected = DType.INT16
            elif mode == ResizeMode.BILINEAR and input_dtype == DType.INT8:
                expected = DType.INT32
            elif mode == ResizeMode.BILINEAR and input_dtype == DType.INT16:
                expected = DType.INT48
            elif input_dtype == DType.FLOAT:
                expected = DType.FLOAT
            else:
                expected = None
            if expected is not None and output_dtype != expected:
                error_result = True
        elif op['op'] == Op.RESCALE:
            if input_dtype == DType.INT8:
                accepted = [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]
            elif input_dtype in [DType.INT16, DType.INT32, DType.INT48]:
                accepted = [DType.INT8, DType.INT16, DType.INT32]
            elif input_dtype == DType.UINT8:
                accepted = [DType.INT8]
            else:
                accepted = None
            if accepted is not None and output_dtype not in accepted:
                error_result = True
        elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
            if input_dtype == DType.INT8 and output_dtype != DType.INT32:
                error_result = True
            elif input_dtype == DType.INT16 and output_dtype != DType.INT48:
                error_result = True
            elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                error_result = True
        elif op['op'] == Op.ARGMAX:
            if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
                error_result = True
        elif output_dtype != input_dtype:
            error_result = True

    return {
        "error_name": ErrorIf.WrongOutputType,
        "error_result": error_result,
        "error_reason": "Output data type not supported for this configuration of operator",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
1426
@staticmethod
def evWrongRank(check=False, **kwargs):
    """Validator for the WrongRank ERROR_IF condition.

    Requires kwargs['op'], a dict providing 'op' (operator id) and
    'rank' (a (min, max) pair). When check is true, kwargs['input_shape']
    is read as well and its length is tested against the ranks accepted
    for the operator.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs. param_reqs['rank'] holds the unsupported ranks.
    """
    assert 'op' in kwargs
    op = kwargs['op']
    op_id = op['op']
    rmin, rmax = op['rank']
    rank_range = range(rmin, rmax + 1)

    # Unsupported ranks are those in 1..5 outside the operator's range.
    candidate_ranks = (1, 2, 3, 4, 5)
    unsupported = list(set(candidate_ranks) - set(rank_range))
    # Remove small incorrect ranks to avoid index errors
    incorrect_ranks = [r for r in unsupported if r > rmin]
    # Set minimum incorrect rank to 3 to avoid index error
    if op_id in [Op.RESIZE]:
        incorrect_ranks = [3, 5]

    error_result = False
    if check:
        rank = len(kwargs['input_shape'])
        # Operator-specific fixed ranks are tested first; any case they
        # do not flag falls back to the generic range test.
        error_result = (
            (op_id in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and rank != 4)
            or (op_id == Op.FULLY_CONNECTED and rank != 2)
            or (op_id == Op.MATMUL and rank != 3)
            or rank not in rank_range
        )

    return {
        "error_name": ErrorIf.WrongRank,
        "error_result": error_result,
        "error_reason": "Rank not supported for this operator",
        "param_reqs": {"rank": incorrect_ranks, "dtype": None, "shape": None}
    }
1468
@staticmethod
def evWrongInputList(check=False, **kwargs):
    """Validator for the WrongInputList ERROR_IF condition.

    When check is true, kwargs['input_list'] and kwargs['num_operands']
    are read and the error is reported if the list length differs from
    the operand count.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        input_list = kwargs['input_list']
        num_operands = kwargs['num_operands']
        if len(input_list) != num_operands:
            error_result = True

    return {
        "error_name": ErrorIf.WrongInputList,
        "error_result": error_result,
        "error_reason": "Op input list does not match expected input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
1490
@staticmethod
def evWrongOutputList(check=False, **kwargs):
    """Validator for the WrongOutputList ERROR_IF condition.

    When check is true, kwargs['output_list'] is read and the error is
    reported unless it holds exactly one entry.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.WrongOutputList,
        "error_result": False,
        "error_reason": "Op output list does not match expected output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }

    if not check:
        return info_dict

    # Note this will be incorrect if an operator returns more than one output
    if len(kwargs['output_list']) != 1:
        info_dict["error_result"] = True
    return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001511
@staticmethod
def evMaxDimExceeded(check=False, **kwargs):
    """Validator for the MaxDimExceeded ERROR_IF condition.

    When check is true, kwargs['input_shape'] and kwargs['output_shape']
    are read. The error is reported if input_shape[1], input_shape[2],
    output_shape[0] or output_shape[1] is larger than 16384.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        input_shape = kwargs['input_shape']
        output_shape = kwargs['output_shape']  # Note this is just (OH, OW)
        # (shape, index) pairs are indexed lazily, in the original order.
        checked_dims = (
            (input_shape, 1),
            (input_shape, 2),
            (output_shape, 0),
            (output_shape, 1),
        )
        if any(shape[idx] > 16384 for shape, idx in checked_dims):
            error_result = True

    return {
        "error_name": ErrorIf.MaxDimExceeded,
        "error_result": error_result,
        "error_reason": "At least one maximum dimension is larger than 16384",
        "param_reqs": {
            "rank": [4,4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
        }
    }
1539
@staticmethod
def evBatchMismatch(check=False, **kwargs):
    """Validator for the BatchMismatch ERROR_IF condition.

    Requires kwargs['op'] with a 'rank' (min, max) pair. When check is
    true, kwargs['input_shape'] and kwargs['result_tensor'].shape are
    read; the error is reported if the input rank is supported and the
    first dimensions of the two shapes differ.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    assert 'op' in kwargs
    rmin, rmax = kwargs['op']['rank']
    supported_ranks = range(rmin, rmax + 1)

    error_result = False
    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)

        rank_ok = len(in_shape) in supported_ranks
        if rank_ok and (in_shape[0] != out_shape[0]):
            error_result = True

    return {
        "error_name": ErrorIf.BatchMismatch,
        "error_result": error_result,
        "error_reason": "Input batch size not equal to output batch size",
        "param_reqs": {"rank": [4,4], "dtype": None, "shape": None}
    }
1566
@staticmethod
def evChannelMismatch(check=False, **kwargs):
    """Validator for the ChannelMismatch ERROR_IF condition.

    Requires kwargs['op'] with a 'rank' (min, max) pair. When check is
    true, kwargs['input_shape'] and kwargs['result_tensor'].shape are
    read; the error is reported if the input rank is supported and the
    dimensions at index 3 of the two shapes differ.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    assert 'op' in kwargs
    rmin, rmax = kwargs['op']['rank']
    supported_ranks = range(rmin, rmax + 1)

    error_result = False
    if check:
        in_shape = kwargs['input_shape']
        out_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)

        rank_ok = len(in_shape) in supported_ranks
        if rank_ok and (in_shape[3] != out_shape[3]):
            error_result = True

    return {
        "error_name": ErrorIf.ChannelMismatch,
        "error_result": error_result,
        "error_reason": "Input channel size not equal to output channel size",
        "param_reqs": {"rank": [4,4], "dtype": None, "shape": None}
    }
1592
1593 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001594 def evStrideSmallerEqualZero(check=False, **kwargs):
1595 error_name = ErrorIf.StrideSmallerEqualZero
1596 param_reqs = {"rank": None, "dtype": None, "shape": None}
1597 error_result = False
1598 error_reason = "Stride value smaller than or equal zero"
1599
1600 if check:
1601 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001602 output_dtype = kwargs['output_dtype']
1603 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1604 stride = kwargs['stride'] # Work around wrong input/output type tests
1605 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001606 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001607 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1608 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001609 else:
1610 stride = kwargs['stride']
1611
1612 if min(stride) <= 0:
1613 error_result = True
1614
1615 info_dict = {
1616 "error_name": error_name,
1617 "error_result": error_result,
1618 "error_reason": error_reason,
1619 "param_reqs": param_reqs
1620 }
1621 return info_dict
1622
@staticmethod
def evStrideLargerEqualMax(check=False, **kwargs):
    """Validator for the StrideLargerEqualMax ERROR_IF condition.

    When check is true, kwargs['shift'], kwargs['input_dtype'] and
    kwargs['stride'] are read. For INT8/INT16 inputs the error is
    reported if either stride element reaches the limit, which is 16
    shifted left by shift (or right by -shift when shift is negative).

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride']
        if input_dtype in [DType.INT8, DType.INT16]:
            if shift >= 0:
                limit = 16 << shift
            else:
                limit = 16 >> -shift
            if stride[0] >= limit or stride[1] >= limit:
                error_result = True

    return {
        "error_name": ErrorIf.StrideLargerEqualMax,
        "error_result": error_result,
        "error_reason": "Stride value larger than or equal to maximum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    }
1647
1648
@staticmethod
def evStrideLargerDimension(check=False, **kwargs):
    """Validator for the StrideLargerDimension ERROR_IF condition.

    When check is true, kwargs['input_shape'], kwargs['input_dtype'] and
    kwargs['stride_fp'] are read. For a FLOAT input the error is
    reported if stride_fp[0] exceeds input_shape[1] or stride_fp[1]
    exceeds input_shape[2].

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        shape = kwargs['input_shape']
        input_dtype = kwargs['input_dtype']
        stride = kwargs['stride_fp']

        # Both dimension tests are grouped under the FLOAT guard. Without
        # the parentheses `and` binds tighter than `or`, so the width test
        # applied to every dtype.
        # NOTE(review): the reason string says "larger than or equal" but
        # the comparison is strict - confirm which one is intended.
        if input_dtype == DType.FLOAT and (
            (stride[0] > shape[1]) or (stride[1] > shape[2])
        ):
            error_result = True

    return {
        "error_name": ErrorIf.StrideLargerDimension,
        "error_result": error_result,
        "error_reason": "Stride value larger than or equal to H/W dimension",
        "param_reqs": {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    }
1671
1672
@staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
    """Validator for the OffsetSmallerEqualMin ERROR_IF condition.

    When check is true, kwargs['shift'] and kwargs['output_dtype'] are
    read; the offset comes from kwargs['offset_fp'] for a FLOAT output
    and from kwargs['offset'] otherwise. The error is reported if either
    offset element is at or below the limit, which is -16 shifted left
    by shift (or right by -shift when shift is negative).

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        offset_key = 'offset_fp' if kwargs['output_dtype'] == DType.FLOAT else 'offset'
        offset = kwargs[offset_key]

        if shift >= 0:
            lower_limit = -16 << shift
        else:
            lower_limit = -16 >> -shift
        if offset[0] <= lower_limit or offset[1] <= lower_limit:
            error_result = True

    return {
        "error_name": ErrorIf.OffsetSmallerEqualMin,
        "error_result": error_result,
        "error_reason": "Offset value smaller than or equal to minimum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    }
1700
@staticmethod
def evOffsetLargerEqualMax(check=False, **kwargs):
    """Validator for the OffsetLargerEqualMax ERROR_IF condition.

    When check is true, kwargs['shift'] and kwargs['output_dtype'] are
    read; the offset comes from kwargs['offset_fp'] for a FLOAT output
    and from kwargs['offset'] otherwise. The error is reported if either
    offset element reaches the limit, which is 16 shifted left by shift
    (or right by -shift when shift is negative).

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        output_dtype = kwargs['output_dtype']
        if output_dtype == DType.FLOAT:
            offset = kwargs['offset_fp']
        else:
            offset = kwargs['offset']

        # A single comparison against the shift-scaled limit replaces the
        # previously duplicated non-negative-shift test.
        if shift >= 0:
            upper_limit = 16 << shift
        else:
            upper_limit = 16 >> -shift
        if offset[0] >= upper_limit or offset[1] >= upper_limit:
            error_result = True

    return {
        "error_name": ErrorIf.OffsetLargerEqualMax,
        "error_result": error_result,
        "error_reason": "Offset value larger than or equal to maximum value",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    }
1732
@staticmethod
def evShiftNotZero(check=False, **kwargs):
    """Validator for the ShiftNotZero ERROR_IF condition.

    When check is true, kwargs['shift'], kwargs['input_dtype'] and
    kwargs['output_dtype'] are read; the error is reported if both
    dtypes are FLOAT and shift is not zero.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        all_float = input_dtype == DType.FLOAT and output_dtype == DType.FLOAT
        if all_float and shift != 0:
            error_result = True

    return {
        "error_name": ErrorIf.ShiftNotZero,
        "error_result": error_result,
        "error_reason": "Shift value must be zero for float input",
        "param_reqs": {"rank": None, "dtype": [DType.FLOAT], "shape": None}
    }
1754
1755
@staticmethod
def evShiftSmallerOne(check=False, **kwargs):
    """Validator for the ShiftSmallerOne ERROR_IF condition.

    When check is true, kwargs['shift'], kwargs['input_dtype'] and
    kwargs['output_dtype'] are read; the error is reported if shift is
    below one and neither dtype is FLOAT.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        shift = kwargs['shift']
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        if shift < 1:
            if input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
                error_result = True

    return {
        "error_name": ErrorIf.ShiftSmallerOne,
        "error_result": error_result,
        "error_reason": "Shift value smaller than one",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    }
1777
@staticmethod
def evShiftLargerEleven(check=False, **kwargs):
    """Validator for the ShiftLargerEleven ERROR_IF condition.

    When check is true, kwargs['shift'] is read and the error is
    reported if it is greater than 11.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.ShiftLargerEleven,
        "error_result": False,
        "error_reason": "Shift value larger than eleven",
        "param_reqs": {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
    }

    if not check:
        return info_dict

    if kwargs['shift'] > 11:
        info_dict["error_result"] = True
    return info_dict
1797
1798
@staticmethod
def evRankMismatch(check=False, **kwargs):
    """Validator for the RankMismatch ERROR_IF condition.

    When check is true, the .shape attributes of kwargs['input1'],
    kwargs['input2'] and kwargs['result_tensor'] are read; the error is
    reported if either input's rank differs from the result's rank.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        rank_in1 = len(kwargs['input1'].shape)
        rank_in2 = len(kwargs['input2'].shape)
        rank_out = len(kwargs['result_tensor'].shape)
        if rank_in1 != rank_out or rank_in2 != rank_out:
            error_result = True

    return {
        "error_name": ErrorIf.RankMismatch,
        "error_result": error_result,
        "error_reason": "Input Rank does not match output rank",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
1820
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
    """Validator for the InputZeroPointNotZero ERROR_IF condition.

    Requires kwargs['op'] with 'op' and 'types' entries. param_reqs
    'dtype' is built from a copy of op['types'] (sliced from index 2
    when the copy is a list) with one INT8 and one UINT8 entry removed
    if present.

    When check is true, kwargs['input_dtype'] and kwargs['qinfo'] are
    read. qinfo is either a tuple whose first element is the input zero
    point, or an object whose .ints[0][1] is the input zero point. For
    MATMUL, kwargs['input2_dtype'] and qinfo.ints[1][1] are read too and
    the error is reported if either operand is not INT8 and has a
    non-zero zero point. For other operators the error is reported if
    the input dtype is neither INT8 nor UINT8 and its zero point is not
    zero.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    op = kwargs['op']
    candidate_dtypes = op['types'].copy()
    # If inputDtypes is a list then only the first two elements are INT8 inputs
    if isinstance(candidate_dtypes, list):
        candidate_dtypes = candidate_dtypes[2:]

    # Drop (at most one occurrence of) each dtype that allows a zero point.
    for zp_dtype in (DType.INT8, DType.UINT8):
        if zp_dtype in candidate_dtypes:
            candidate_dtypes.remove(zp_dtype)

    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        qinfo_arg = kwargs['qinfo']
        if isinstance(qinfo_arg, tuple):
            input_zero_point = qinfo_arg[0]
        else:
            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            input_zero_point = qinfo_arg.ints[0][1]

        if op['op'] == Op.MATMUL:
            input1_dtype = kwargs['input_dtype']
            input2_dtype = kwargs['input2_dtype']
            zero_points = kwargs['qinfo'].ints
            input1_zero_point = zero_points[0][1]
            input2_zero_point = zero_points[1][1]
            first_invalid = input1_dtype != DType.INT8 and input1_zero_point != 0
            second_invalid = input2_dtype != DType.INT8 and input2_zero_point != 0
            if first_invalid or second_invalid:
                error_result = True
        elif input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
            error_result = True

    return {
        "error_name": ErrorIf.InputZeroPointNotZero,
        "error_result": error_result,
        "error_reason": "Input DType not INT8 and zero point not 0",
        "param_reqs": {
            "rank": None,
            "dtype": candidate_dtypes,
            "shape": None
        }
    }
1872
1873
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
    """Validator for the WeightZeroPointNotZero ERROR_IF condition.

    Requires kwargs['op'] with a 'types' entry. param_reqs 'dtype' keeps
    every entry of op['types'] except list entries whose element at
    index 1 is INT8.

    When check is true, kwargs['weight_dtype'] and kwargs['qinfo'].ints
    are read; the error is reported if the weight dtype is not INT8 and
    qinfo.ints[1][1] is not zero.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    op = kwargs['op']

    # exclude inputs with INT8 weights
    kept_dtypes = []
    for entry in op['types']:
        if isinstance(entry, list) and entry[1] == DType.INT8:
            continue
        kept_dtypes.append(entry)

    error_result = False

    if check:
        weight_dtype = kwargs['weight_dtype']
        # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
        weight_zero_point = kwargs['qinfo'].ints[1][1]
        if weight_dtype != DType.INT8 and weight_zero_point != 0:
            error_result = True

    return {
        "error_name": ErrorIf.WeightZeroPointNotZero,
        "error_result": error_result,
        "error_reason": "Weight DType not INT8 and zero point not 0",
        "param_reqs": {
            "rank": None,
            "dtype": kept_dtypes,
            "shape": None
        }
    }
1906
1907
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
    """Validator for the OutputZeroPointNotZero ERROR_IF condition.

    Requires kwargs['op'] with 'op' and 'types' entries. param_reqs
    'dtype' is a copy of op['types'] with one INT8 and one UINT8 entry
    removed if present.

    When check is true, kwargs['input_dtype'], kwargs['output_dtype']
    and kwargs['qinfo'] are read. qinfo is either a tuple whose second
    element is the output zero point, or an object whose .ints[1][1] is
    the output zero point. For AVG_POOL2D the error is reported if the
    input dtype is not INT8 and the output zero point is not zero; for
    other operators if the output dtype is neither INT8 nor UINT8 and
    the output zero point is not zero.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    op = kwargs['op']
    candidate_dtypes = op['types'].copy()
    for zp_dtype in (DType.INT8, DType.UINT8):
        if zp_dtype in candidate_dtypes:
            candidate_dtypes.remove(zp_dtype)

    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        output_dtype = kwargs['output_dtype']
        qinfo_arg = kwargs['qinfo']
        if isinstance(qinfo_arg, tuple):
            output_zero_point = qinfo_arg[1]
        else:
            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            output_zero_point = qinfo_arg.ints[1][1]

        if op['op'] == Op.AVG_POOL2D:
            # AVG_POOL2D is judged on its input dtype.
            dtype_disallows_zp = input_dtype != DType.INT8
        else:
            dtype_disallows_zp = output_dtype not in [DType.INT8, DType.UINT8]
        if dtype_disallows_zp and output_zero_point != 0:
            error_result = True

    return {
        "error_name": ErrorIf.OutputZeroPointNotZero,
        "error_result": error_result,
        "error_reason": "Output DType not INT8 and zero point not 0",
        "param_reqs": {
            "rank": None,
            "dtype": candidate_dtypes,
            "shape": None
        }
    }
1949
@staticmethod
def evAxisSmallerZero(check=False, **kwargs):
    """Validator for the AxisSmallerZero ERROR_IF condition.

    When check is true, kwargs['axis'] is read and the error is reported
    if it is negative.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.AxisSmallerZero,
        "error_result": False,
        "error_reason": "Axis smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }

    if not check:
        return info_dict

    if kwargs['axis'] < 0:
        info_dict["error_result"] = True
    return info_dict
1969
1970
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
    """Validator for the AxisLargerRank ERROR_IF condition.

    When check is true, kwargs['axis'] and kwargs['input_shape'] are
    read; the error is reported if axis is not a valid index into the
    shape from above, i.e. axis >= len(input_shape).

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        axis = kwargs['axis']
        shape = kwargs['input_shape']
        # Valid axes are 0 .. rank-1, so axis == rank is already out of
        # range. The previous strict `>` comparison let that case through.
        if axis >= len(shape):
            error_result = True

    return {
        "error_name": ErrorIf.AxisLargerRank,
        "error_result": error_result,
        "error_reason": "Axis larger than rank",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
1991
1992
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
    """Validator for the ShapeOfAxisNotOne ERROR_IF condition.

    When check is true, kwargs['axis'] and kwargs['output_shape'] are
    read; the error is reported if axis is a valid index into the output
    shape and the dimension at that index is not 1.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        axis = kwargs['axis']
        out_shape = kwargs['output_shape']
        axis_in_range = 0 <= axis < len(out_shape)
        if axis_in_range and out_shape[axis] != 1:
            error_result = True

    return {
        "error_name": ErrorIf.ShapeOfAxisNotOne,
        "error_result": error_result,
        "error_reason": "shape[axis] is not equal to 1",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2013
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002014
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
    """Validator for the PadSmallerZero ERROR_IF condition.

    When check is true, kwargs['pad'] is read and the error is reported
    if its smallest element is negative.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    info_dict = {
        "error_name": ErrorIf.PadSmallerZero,
        "error_result": False,
        "error_reason": "At least one pad is smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }

    if not check:
        return info_dict

    if min(kwargs['pad']) < 0:
        info_dict["error_result"] = True
    return info_dict
2034
2035
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
    """Validator for the PadLargerEqualKernel ERROR_IF condition.

    When check is true, kwargs['pad'] (four values) and kwargs['kernel']
    (two values) are read. Provided every pad is positive and every
    kernel dimension is above one, the error is reported if pad[0] or
    pad[1] reaches kernel[0], or pad[2] or pad[3] reaches kernel[1].

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        pad = kwargs['pad']
        kernel = kwargs['kernel']
        # NOTE(review): the guard needs every pad to be strictly positive,
        # so a zero pad next to an oversized one is not reported here -
        # confirm that this is intended.
        if min(pad) > 0 and min(kernel) > 1:
            # (pad index, kernel index) pairs, evaluated lazily in order.
            pad_to_kernel = ((0, 0), (1, 0), (2, 1), (3, 1))
            if any(pad[p] >= kernel[k] for p, k in pad_to_kernel):
                error_result = True

    return {
        "error_name": ErrorIf.PadLargerEqualKernel,
        "error_result": error_result,
        "error_reason": "At least one pad is larger than kernel dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2057
@staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
    """Validator for the PoolingOutputShapeMismatch ERROR_IF condition.

    When check is true, kwargs['pad'] (four values), kwargs['kernel'],
    kwargs['stride'] (two values each), kwargs['input_shape'] and
    kwargs['output_shape'] are read. If the kernel, stride and pad
    values are themselves valid, the error is reported when
    output_shape[1] or output_shape[2] differs from the size computed
    as (input + pads + stride - kernel) // stride for that dimension.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        pad = kwargs['pad']
        pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

        kernel = kwargs['kernel']
        kernel_y, kernel_x = kernel[0], kernel[1]

        input_shape = kwargs['input_shape']
        in_h, in_w = input_shape[1], input_shape[2]

        output_shape = kwargs['output_shape']
        out_h, out_w = output_shape[1], output_shape[2]

        stride = kwargs['stride']
        stride_y, stride_x = stride[0], stride[1]

        # ensure parameters are valid
        pads_below_kernel = (
            pad_top < kernel_y and pad_bottom < kernel_y
            and pad_left < kernel_x and pad_right < kernel_x
        )
        params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
                        and pads_below_kernel)

        if params_valid:
            # calculate correct height, width dimensions; strides are
            # at least 1 here, so the divisions are safe
            expected_h = (in_h + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
            expected_w = (in_w + pad_left + pad_right + stride_x - kernel_x) // stride_x
            if out_h != expected_h or out_w != expected_w:
                error_result = True

    return {
        "error_name": ErrorIf.PoolingOutputShapeMismatch,
        "error_result": error_result,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2100
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
    """Validator for the ArgmaxOutputShapeMismatch ERROR_IF condition.

    When check is true, kwargs['output_shape'], kwargs['input_shape']
    and kwargs['axis'] are read. If the output has exactly one dimension
    fewer than the input, each input dimension other than the one at
    axis is compared with the corresponding output dimension and the
    error is reported on any difference.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        output_shape = kwargs['output_shape']
        input_shape = kwargs['input_shape']
        axis = kwargs['axis']

        # Check that rank is correct before trying to check dimensions
        if (len(input_shape) - 1) == len(output_shape):
            # Output indices lag input indices by one after the axis.
            lag = 0
            for idx, dim in enumerate(input_shape):
                if idx == axis:
                    lag = 1
                    continue
                if dim != output_shape[idx - lag]:
                    error_result = True

    return {
        "error_name": ErrorIf.ArgmaxOutputShapeMismatch,
        "error_result": error_result,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": [2,4], "dtype": None, "shape": None}
    }
2135
@staticmethod
def evArgmaxOutputRankMismatch(check=False, **kwargs):
    """Validator for the ArgmaxOutputRankMismatch ERROR_IF condition.

    When check is true, kwargs['output_shape'], kwargs['input_shape']
    and kwargs['axis'] are read. If axis is a valid index into the input
    shape, the error is reported when the output rank is not the input
    rank minus one.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        out_rank = len(kwargs['output_shape'])
        in_rank = len(kwargs['input_shape'])
        axis = kwargs['axis']
        axis_valid = 0 <= axis < in_rank

        if axis_valid and (in_rank - 1) != out_rank:
            error_result = True

    return {
        "error_name": ErrorIf.ArgmaxOutputRankMismatch,
        "error_result": error_result,
        "error_reason": "Mismatch between output shape provided and expected output shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2159
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002160
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
    """Validator for the KernelSmallerOne ERROR_IF condition.

    When check is true, kwargs['kernel'] is read and the error is
    reported if its smallest element is below one.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        kernel = kwargs['kernel']
        if min(kernel) < 1:
            error_result = True

    return {
        "error_name": ErrorIf.KernelSmallerOne,
        "error_result": error_result,
        # The threshold tested above is one, not zero.
        "error_reason": "At least one kernel dimension is smaller than one",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2180
@staticmethod
def evStrideSmallerOne(check=False, **kwargs):
    """Validator for the StrideSmallerOne ERROR_IF condition.

    When check is true, kwargs['stride'] is read and the error is
    reported if its smallest element is below one.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        stride = kwargs['stride']
        if min(stride) < 1:
            error_result = True

    return {
        "error_name": ErrorIf.StrideSmallerOne,
        "error_result": error_result,
        # The threshold tested above is one, not zero.
        "error_reason": "At least one stride dimension is smaller than one",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2200
@staticmethod
def evScaleTrue(check=False, **kwargs):
    """Validator for the ScaleTrue ERROR_IF condition.

    When check is true, kwargs['input_dtype'] and kwargs['scale32'] are
    read; the error is reported if scale32 is truthy and the input dtype
    is INT48.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        input_dtype = kwargs['input_dtype']
        scale32 = kwargs['scale32']
        if scale32:
            if input_dtype == DType.INT48:
                error_result = True

    return {
        "error_name": ErrorIf.ScaleTrue,
        "error_result": error_result,
        "error_reason": "Scale set to true but input type is INT48",
        "param_reqs": {"rank": None, "dtype": [DType.INT48], "shape": None}
    }
2221
@staticmethod
def evScaleNotTrue(check=False, **kwargs):
    """Validator for the ScaleNotTrue ERROR_IF condition.

    When check is true, kwargs['scale32'] and kwargs['double_round'] are
    read; the error is reported if scale32 is falsy while double_round
    is truthy.

    Returns a dict with the keys error_name, error_result, error_reason
    and param_reqs.
    """
    error_result = False

    if check:
        scale32 = kwargs['scale32']
        double_round = kwargs['double_round']
        if double_round and not scale32:
            error_result = True

    return {
        "error_name": ErrorIf.ScaleNotTrue,
        "error_result": error_result,
        "error_reason": "Scale set to false but double round set to true",
        "param_reqs": {"rank": None, "dtype": None, "shape": None}
    }
2242
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002243
2244
class TosaInvalidValidator:
    """Predicates over generated test arguments.

    Every static method takes keyword arguments and returns True when
    the argument combination it inspects is invalid, False otherwise.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True for an unsupported resize mode / dtype combination.

        Reads kwargs["input_dtype"] and kwargs["args"], where args[0] is
        the resize mode and args[8] the output dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # The combination is valid when it matches ONE of the accepted
            # (input, output) pairs. The previous form or-ed the negation of
            # each pair, which is always true because an input dtype can
            # match at most one pair.
            valid_pair = (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
            return not valid_pair
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            # NOTE(review): INT32 in this list looks like a typo for INT16
            # (the bilinear pairs above use INT8/INT16/FLOAT inputs) -
            # confirm before changing, as it alters which tests are kept.
            return (
                (input_dtype != output_dtype) or
                (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True when the stride for the input dtype is zero or negative.

        Reads kwargs["input_dtype"] and kwargs["args"]; args[4] holds the
        stride pair examined for FLOAT inputs and args[1] the pair
        examined for every other dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True when the computed output height or width is <= 0.

        Reads kwargs['opName'], kwargs['shapeList'] and kwargs['args'].
        shapeList[0] is the input shape; for non-pooling operators
        shapeList[1] is the filter shape. args holds strides, padding and
        then either dilations (convolutions) or the kernel (pooling).

        Raises AssertionError for an operator name that is not a conv2d,
        depthwise_conv2d or *pool2d variant.
        """
        opName = kwargs['opName']
        is_pool = opName.endswith("pool2d")

        inputShapes = kwargs['shapeList']
        ifm_shape = inputShapes[0]
        if not is_pool:
            filter_shape = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if is_pool:
            kernel = args[2]

        if opName.startswith('conv2d'):
            h = (
                ifm_shape[1]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[2]
                - (filter_shape[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            h = (
                ifm_shape[1]
                - filter_shape[0]
                - (filter_shape[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif is_pool:
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            # Raised explicitly so the check is not stripped under -O.
            raise AssertionError("Unrecognized Op")

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True when args[3][1] or args[3][2] is zero or negative.

        Reads kwargs['args']; args[3] is used as the output shape.
        """
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
2361
2362
Kevin Cheng550ccc52021-03-03 11:21:43 -08002363
Eric Kunzee5e26762020-10-13 16:11:07 -07002364class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002365 # Maximum rank of tensor supported by test generator.
2366 TOSA_TENSOR_MAX_RANK = 6
2367
def __init__(self, args):
    """Set up the generator from parsed arguments.

    args must expose output_dir and random_seed attributes. A NumPy
    random generator is seeded from random_seed, then the dynamic op
    lists and op list defaults are initialised and a TosaQuantGen is
    created.
    """
    seed = args.random_seed

    self.args = args
    self.basePath = args.output_dir
    self.random_seed = seed
    self.ser = None
    self.rng = np.random.default_rng(seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
2379
def createSerializer(self, opName, testPath):
    """Create the output directory for a test and a serializer bound to it.

    Stores ``opName/testPath`` in ``self.testPath``, makes sure the directory
    ``basePath/opName/testPath`` exists and assigns a new
    ``ts.TosaSerializer`` for that directory to ``self.ser``.
    """
    relative_dir = os.path.join(opName, testPath)
    self.testPath = relative_dir

    output_dir = os.path.join(self.basePath, relative_dir)
    os.makedirs(output_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(output_dir)
2386
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    return self.ser
2389
def serialize(self, testName):
    """Write the serialized test and its JSON descriptor to the test directory.

    Produces ``<testName>.tosa`` (binary, from ``self.ser.serialize()``) and
    ``desc.json`` (text, from ``self.ser.writeJson``) under
    ``basePath/testPath``.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_name = "{}.tosa".format(testName)

    # Open first, then serialize, so the file is created even if
    # serialization fails (same as the original ordering).
    with open(os.path.join(test_dir, tosa_name), "wb") as tosa_file:
        tosa_file.write(self.ser.serialize())

    with open(os.path.join(test_dir, "desc.json"), "w") as desc_file:
        desc_file.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002398
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded NumPy generator.

    Args:
        seed: seed passed to ``np.random.default_rng``. When omitted (None),
            ``self.random_seed + 1`` is used.

    The None check uses identity rather than ``==`` so that array-like
    seeds, which ``default_rng`` accepts, do not trigger an element-wise
    comparison and an ambiguous-truth-value error.
    """
    if seed is None:
        seed = self.random_seed + 1
    self.rng = np.random.default_rng(seed)
2403
def getRandTensor(self, shape, dtype):
    """Return random data of the given shape for a TOSA dtype.

    Args:
        shape: shape passed as ``size`` to the NumPy generator.
        dtype: one of the ``DType`` values handled below.

    Returns:
        np.bool_ data for BOOL, np.int64 data for INT48, np.float32 data
        for FLOAT and np.int32 data for the remaining integer types.
        Integer upper bounds are exclusive.

    Raises:
        Exception: if ``dtype`` is not one of the handled values.
    """
    # The former unused `np_dt = np.bool` assignment was dropped: the alias
    # was removed from NumPy 1.24-1.26, which made BOOL generation fail.
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        return np.int32(self.rng.integers(low=-7, high=8, size=shape))
    elif dtype == DType.INT8:
        return np.int32(self.rng.integers(low=-128, high=128, size=shape))
    elif dtype == DType.UINT8:
        return np.int32(self.rng.integers(low=0, high=256, size=shape))
    elif dtype == DType.INT16:
        return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
    elif dtype == DType.INT32:
        return np.int32(
            self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
        )
    elif dtype == DType.INT48:
        # Does not fit in 32 bits, so stored as int64.
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002429
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Create one randomly filled placeholder tensor per (shape, dtype) pair.

    Args:
        shape_list: shapes of the tensors to create.
        dtype_list: dtypes, one per shape (lengths must match).

    Returns:
        List of the values returned by ``self.ser.addPlaceholder``, in
        input order.
    """
    tensors = []

    assert len(shape_list) == len(dtype_list)

    for shape, dtype in zip(shape_list, dtype_list):
        data = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))

    return tensors
2440
def buildConstTensors(self, shape_list, dtype_list):
    """Create one randomly filled constant tensor per (shape, dtype) pair.

    Args:
        shape_list: shapes of the tensors to create.
        dtype_list: dtypes, one per shape (lengths must match).

    Returns:
        List of the values returned by ``self.ser.addConst``, in input order.
    """
    tensors = []

    assert len(shape_list) == len(dtype_list)

    for shape, dtype in zip(shape_list, dtype_list):
        data = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, data))

    return tensors
2451
def makeShape(self, rank):
    """Return a shape as an int32 array.

    When a (truthy) target shape has been set via setTargetShape it is
    returned as-is, regardless of ``rank``. Otherwise ``rank`` dimensions are
    drawn from ``self.args.tensor_shape_range`` (upper bound exclusive).
    """
    forced = self.targetted_shape
    if forced:
        return np.int32(forced)

    low = self.args.tensor_shape_range[0]
    high = self.args.tensor_shape_range[1]
    dims = self.rng.integers(low=low, high=high, size=rank)
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -07002462
def setTargetShape(self, shape):
    """Set the shape that makeShape will return instead of a random one.

    Passing None (or any falsy value) restores random shape generation.
    """
    self.targetted_shape = shape
2465
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from [low, high) using ``self.rng``."""
    sample = self.rng.integers(low=low, high=high, size=1)
    return np.int32(sample)[0]
2468
def getRandNumberDType(self, dtype):
    """Return a single random value suitable for the given TOSA dtype.

    FLOAT yields ``self.rng.random()``, BOOL a random choice of
    False/True, INT48 an np.int64 scalar and INT4/INT8/INT16/INT32 an
    np.int32 scalar. Integer upper bounds are exclusive.

    Raises:
        Exception: if ``dtype`` is not one of the handled values.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # Special size: INT48 does not fit in 32 bits.
    if dtype == DType.INT48:
        wide = self.rng.integers(-(1 << 47), (1 << 47), size=1)
        return np.int64(wide)[0]

    int_bounds = (
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, (-7, 8)),
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, (-(1 << 31), (1 << 31))),
    )
    for candidate, (low, high) in int_bounds:
        if dtype == candidate:
            return np.int32(self.rng.integers(low, high, size=1))[0]

    raise Exception("Unknown dtype: {}".format(dtype))
2491
def shapeStr(self, shape):
    """Render a shape as its dimensions joined by "x" (e.g. "1x2x3")."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002500
def typeStr(self, t):
    """Return the short name used for a dtype.

    For a list, the names of its first two entries are joined with "x";
    any further entries are ignored.

    Raises:
        Exception: if the dtype has no short name.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        first = self.typeStr(t[0])
        second = self.typeStr(t[1])
        return "{}x{}".format(first, second)

    short_names = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for candidate, label in short_names:
        if t == candidate:
            return label

    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002524
def typeWidth(self, t):
    """Get the width in bits associated with a dtype.

    Covers the integer types as well as FLOAT (32) and BOOL (1).

    Raises:
        Exception: if the dtype is not one of the handled values.
    """
    bit_widths = (
        (DType.INT4, 4),
        (DType.INT8, 8),
        (DType.UINT8, 8),
        (DType.INT16, 16),
        (DType.INT32, 32),
        (DType.INT48, 48),
        (DType.FLOAT, 32),
        (DType.BOOL, 1),
    )
    for candidate, bits in bit_widths:
        if t == candidate:
            return bits

    # NOTE(review): the message mentions string conversion although this
    # function computes a width; wording kept unchanged for compatibility.
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002545
2546 # Argument generators
2547 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2548 # Where the string descriptor is used to generate the test name and
2549 # The build_fcn_arg_list is expanded and passed to the operator test
2550 # build function
2551
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator to the serializer and return its result tensor.

    Args:
        op: either a bare int opcode or an op dictionary with "op" and
            "operands" entries.
        a: input tensor.
        validator_fcns: validators forwarded to the ERROR_IF checker.
        error_name: ERROR_IF case being generated, or None.
        qinfo: quantization info forwarded to the serializer.

    A bare int opcode and IDENTITY are emitted directly, without the
    ERROR_IF handling.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops does not
    if isinstance(op, int):
        direct_opcode = op
    elif op['op'] == Op.IDENTITY:
        direct_opcode = op['op']
    else:
        direct_opcode = None

    if direct_opcode is not None:
        # NOTE(review): names are passed bare here, not wrapped in lists as
        # on the main path below - confirm this is intended.
        self.ser.addOperator(direct_opcode, a.name, out_tens.name, None, qinfo)
        return out_tens

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and out_tens.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        in_q = TosaQuantGen.getQinfo(self, a.dtype)
        out_q = TosaQuantGen.getQinfo(self, out_tens.dtype)
        qinfo.UnaryQuantInfo(in_q, out_q)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
    return out_tens
2594
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a two-input broadcasting operator and return its result tensor.

    The result tensor comes from ``OutputShaper.binaryBroadcastOp``; the
    input/output name lists may be altered for ERROR_IF generation before
    validation and serialization.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op['op'], input_list, output_list)
    return out_tens
2623
def build_binary_nonbroadcast(self, op, a, b):
    """Add a two-input operator without broadcasting and return its result tensor."""
    out_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2628
def build_arithmetic_right_shift(self, op, a, b, round):
    """Add an arithmetic-right-shift operator and return its result tensor.

    ``round`` is stored in the operator's ArithmeticRightShiftAttribute.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(
        op['op'], [a.name, b.name], [out_tens.name], shift_attr
    )
    return out_tens
2637
def build_mul(self, op, a, b, shift):
    """Add a multiply operator and return its result tensor.

    ``shift`` is stored in the operator's MulAttribute. For any non-FLOAT
    input the result dtype is forced to INT32.
    """
    out_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

    # Special for multiply:
    # Force the result to INT32 for INT types
    is_float_input = a.dtype == DType.FLOAT
    if not is_float_input:
        out_tens.setDtype(DType.INT32)

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op['op'], [a.name, b.name], [out_tens.name], mul_attr)
    return out_tens
2651
def build_table(self, op, a):
    """Add a table-lookup operator with a random constant table.

    The table has 513 INT16 entries for an INT16 input and 256 INT8 entries
    for an INT8 input; any other input dtype fails the assertion.
    """
    # Constant size depending on type, random values
    if a.dtype == DType.INT16:
        table_dtype, entry_count = DType.INT16, 513
    else:
        assert a.dtype == DType.INT8
        table_dtype, entry_count = DType.INT8, 256
    table_data = self.getRandTensor([entry_count], table_dtype)

    # The table constant is registered before the result tensor is created.
    table_tens = self.ser.addConst(table_data.shape, table_dtype, table_data)
    out_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
    self.ser.addOperator(
        op['op'], [a.name, table_tens.name], [out_tens.name], None
    )

    return out_tens
2667
def build_select(self, op, cond, a, b):
    """Add a select operator with inputs (cond, a, b) and return its result tensor."""
    out_tens = OutputShaper.selectOp(self.ser, cond, a, b)
    operand_names = [cond.name, a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2672
def build_comparison(self, op, a, b):
    """Add a binary comparison operator and return its result tensor."""
    out_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
2677
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator along ``axis`` and return its result tensor.

    The input/output name lists may be altered for ERROR_IF generation
    before validation; ``axis`` is stored in an AxisAttribute.
    """
    out_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], input_list, output_list, axis_attr)
    return out_tens
2709
def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
    """Add a 2D pooling operator and return its result tensor.

    ``kernel``, ``stride`` and ``pad`` are stored in a PoolAttribute. For the
    WrongInputType ERROR_IF case with a non-INT8/UINT8 input, ``qinfo`` is
    replaced by freshly generated unary quantization info.
    """
    out_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongInputType
        and input.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        in_q = TosaQuantGen.getQinfo(self, input.dtype)
        out_q = TosaQuantGen.getQinfo(self, out_tens.dtype)
        qinfo.UnaryQuantInfo(in_q, out_q)

    # Invalidate Input/Output list for error if checks.
    input_list = [input.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad)

    self.ser.addOperator(op['op'], input_list, output_list, pool_attr, qinfo)
    return out_tens
2752
def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Add a 2D convolution operator and return its result tensor.

    ``padding`` must have exactly 4 entries. Padding, strides and dilations
    are stored in a ConvAttribute; inputs are (ifm, filter, bias).
    """
    assert len(padding) == 4
    out_tens = OutputShaper.conv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], conv_attr, qinfo)
    return out_tens
2766
def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Add a 3D convolution operator and return its result tensor.

    ``padding`` must have exactly 6 entries. Padding, strides and dilations
    are stored in a ConvAttribute; inputs are (ifm, filter, bias).
    """
    assert len(padding) == 6
    out_tens = OutputShaper.conv3dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], conv_attr, qinfo)
    return out_tens
2780
def build_transpose_conv2d(
    self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
):
    """Add a transpose convolution operator and return its result tensor.

    ``outpad`` must have exactly 2 entries. The result tensor is derived from
    ``ifm`` and ``output_shape``; outpad, stride, dilation and output_shape
    are stored in a TransposeConvAttribute.
    """
    assert len(outpad) == 2
    out_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], tconv_attr, qinfo)
    return out_tens
2794
def build_depthwise_conv2d(
    self, op, ifm, filter, bias, strides, padding, dilations, qinfo
):
    """Add a depthwise 2D convolution operator and return its result tensor.

    Padding, strides and dilations are stored in a ConvAttribute; inputs are
    (ifm, filter, bias). Unlike build_conv2d, the padding length is not
    asserted here.
    """
    out_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    operand_names = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], conv_attr, qinfo)
    return out_tens
2809
def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
    """Add a fully-connected operator and return its result tensor.

    Inputs are (ifm, filter, bias). The input/output name lists may be
    altered for ERROR_IF generation before validation and serialization.
    """
    out_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
    return out_tens
2841
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Add a matrix-multiply operator and return its result tensor.

    The input/output name lists may be altered for ERROR_IF generation
    before validation and serialization.
    """
    out_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
    return out_tens
2872
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Add a reduction operator along ``axis`` and return its result tensor.

    The input/output name lists may be altered for ERROR_IF generation
    before validation; ``axis`` is stored in an AxisAttribute.
    """
    out_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [out_tens.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op['op'], input_list, output_list, axis_attr)
    return out_tens
2904
def build_clamp(self, op, a):
    """Add a clamp operator with random bounds and return its result tensor.

    Two random values of the input dtype are drawn; the smaller becomes the
    lower bound and the larger the upper bound. FLOAT inputs fill the last
    two ClampAttribute arguments, other dtypes the first two; the unused
    pair is set to 0.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    clamp_attr = ts.TosaSerializerAttribute()
    first = self.getRandNumberDType(a.dtype)
    second = self.getRandNumberDType(a.dtype)
    bounds = [first, second]
    lower, upper = min(bounds), max(bounds)

    if a.dtype == DType.FLOAT:
        clamp_attr.ClampAttribute(0, 0, lower, upper)
    else:
        clamp_attr.ClampAttribute(lower, upper, 0, 0)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], clamp_attr)
    return out_tens
2918
def build_leaky_relu(self, op, a):
    """Add a leaky-ReLU operator with a random FLOAT attribute value.

    The value passed to LeakyReluAttribute is drawn with
    ``getRandNumberDType(DType.FLOAT)``.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    relu_attr = ts.TosaSerializerAttribute()
    random_value = self.getRandNumberDType(DType.FLOAT)
    relu_attr.LeakyReluAttribute(random_value)

    self.ser.addOperator(op['op'], [a.name], [out_tens.name], relu_attr)
    return out_tens
2927
2928 # Needs an additional type/input
def build_prelu(self, op, a):
    """Add a PReLU operator with ``a`` as its only input and return its result tensor."""
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    input_names, output_names = [a.name], [out_tens.name]

    self.ser.addOperator(op['op'], input_names, output_names)
    return out_tens
2934
def build_sigmoid(self, op, a):
    """Add a sigmoid operator on ``a`` and return its result tensor."""
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    input_names, output_names = [a.name], [out_tens.name]
    self.ser.addOperator(op['op'], input_names, output_names)
    return out_tens
2939
def build_tanh(self, op, a):
    """Add a tanh operator on ``a`` and return its result tensor."""
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
    input_names, output_names = [a.name], [out_tens.name]
    self.ser.addOperator(op['op'], input_names, output_names)
    return out_tens
2944
def build_concat(self, op, *a):
    """Add a concat operator and return its result tensor.

    Args:
        op: op dictionary with an "op" entry.
        *a: the input tensors followed by the axis as the final positional
            argument, which must be a plain ``int``.
    """
    assert type(a[-1]) == int

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    tensors = a[:-1]

    out_tens = OutputShaper.concatOp(self.ser, axis, *tensors)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    input_tensor_names = [tensor.name for tensor in tensors]

    self.ser.addOperator(op['op'], input_tensor_names, [out_tens.name], axis_attr)
    return out_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002963
def build_pad(self, op, a, padding, qinfo):
    """Add a pad operator and return its result tensor.

    ``padding`` is registered as an INT32 constant tensor (its ``.shape`` is
    read, so an array is expected) and passed as the second input.
    """
    out_tens = OutputShaper.padOp(self.ser, a, padding)

    # Need to turn the padding array into a TOSA tensor here.
    # This is one of the few tensor operands that does not get
    # randomly generated
    pad_const = self.ser.addConst(padding.shape, DType.INT32, padding)

    operand_names = [a.name, pad_const.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name], None, qinfo)
    return out_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002976
def build_reshape(self, op, a, newShape):
    """Add a reshape operator to ``newShape`` and return its result tensor."""
    out_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    input_names, output_names = [a.name], [out_tens.name]
    self.ser.addOperator(op['op'], input_names, output_names, reshape_attr)
    return out_tens
2985
def build_reverse(self, op, a, axis):
    """Add a reverse operator along ``axis`` and return its result tensor."""
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    input_names, output_names = [a.name], [out_tens.name]
    self.ser.addOperator(op['op'], input_names, output_names, axis_attr)
    return out_tens
2994
def build_transpose(self, op, a, perms):
    """Add a transpose operator and return its result tensor.

    ``perms`` is registered as a one-dimensional INT32 constant tensor and
    passed as the second input.
    """
    out_tens = OutputShaper.transposeOp(self.ser, a, perms)

    # The permutation constant is registered after the result tensor.
    perms_data = np.int32(perms)
    perms_const = self.ser.addConst([len(perms)], DType.INT32, perms_data)

    operand_names = [a.name, perms_const.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])
    return out_tens
3002
def build_slice(self, op, a, begin, size):
    """Add a slice operator and return its result tensor.

    ``begin`` and ``size`` are stored in a SliceAttribute.
    """
    out_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(begin, size)

    input_names, output_names = [a.name], [out_tens.name]
    self.ser.addOperator(op['op'], input_names, output_names, slice_attr)
    return out_tens
3011
def build_tile(self, op, a, multiples):
    """Add a tile operator and return its result tensor.

    ``multiples`` is stored in a TileAttribute.
    """
    out_tens = OutputShaper.tileOp(self.ser, a, multiples)

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    input_names, output_names = [a.name], [out_tens.name]
    self.ser.addOperator(op['op'], input_names, output_names, tile_attr)
    return out_tens
3020
def build_gather(self, op, values):
    """Add a gather operator with random indices and return its result tensor.

    An INT32 index constant of shape (values.shape[0], W) is generated with
    entries in [0, values.shape[1]), where W is drawn from
    ``self.args.tensor_shape_range``.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values tensor.
    index_limit = values.shape[1]  # K
    shape_range = self.args.tensor_shape_range
    index_count = self.randInt(shape_range[0], shape_range[1])  # W

    index_data = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values.shape[0], index_count]
        )
    )  # (N, W)
    index_const = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.gatherOp(self.ser, values, index_const)

    operand_names = [values.name, index_const.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])

    return out_tens
3040
def build_scatter(self, op, values_in, input):
    """Add a scatter operator with random indices and return its result tensor.

    An INT32 index constant of shape (values_in.shape[0], input.shape[1]) is
    generated with entries in [0, values_in.shape[1]). Operator inputs are
    (values_in, indices, input).
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values_in tensor.
    index_limit = values_in.shape[1]  # K
    index_count = input.shape[1]  # W

    index_data = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values_in.shape[0], index_count]
        )
    )  # (N, W)
    index_const = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.scatterOp(self.ser, values_in, index_const, input)

    operand_names = [values_in.name, index_const.name, input.name]
    self.ser.addOperator(op['op'], operand_names, [out_tens.name])

    return out_tens
3060
Matthew Haddon848efb42021-09-09 12:30:53 +01003061
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a resize operator and run its ERROR_IF validators.

    The output tensor is obtained from ``OutputShaper.resizeOp``; the
    input/output name lists are then passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` before both the
    validation call and the final ``addOperator`` call.

    Returns the tensor produced by ``OutputShaper.resizeOp``.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Total operand count is placeholders plus constants
    num_operands = sum(op["operands"][:2]) if False else None
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    validation_info = dict(
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )
    TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_info
    )

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3130
def build_identityn(self, op, val, val2):
    """Serialize an operator with two inputs and two matching outputs.

    Each output tensor is shaped from its corresponding input through
    ``OutputShaper.unaryOp``.  Only the first result tensor is returned,
    which is the behaviour this builder has always had.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
    # ``op`` is the operator definition dictionary (the sibling builders
    # all read op["op"]), so hand the serializer the opcode rather than
    # the whole dictionary.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
3138
def build_const(self, op, val):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op`` is accepted for signature parity with the other builders and
    is not used.
    """
    output_tensor = val
    self.ser.addOutputTensor(output_tensor)
    return output_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07003142
3143 # Type Conversion
def build_cast(self, op, val, out_dtype):
    """Serialize a cast of ``val`` to ``out_dtype``.

    Returns the output tensor created by ``OutputShaper.typeConversionOp``.
    """
    converted = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
    input_names = [val.name]
    output_names = [converted.name]
    self.ser.addOperator(op["op"], input_names, output_names)
    return converted
3148
def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
    """Serialize a rescale operator with random scaling and zero points.

    Zero points are randomised for INT8/UINT8 tensors (and forced non-zero
    when the matching zero-point ERROR_IF test is requested); otherwise
    they are 0.  One random scale per channel is generated, clipped to the
    scale32/scale16 limit and converted to a multiplier/shift pair with
    ``TosaQuantGen.computeMultiplierAndShift``.

    Returns the output tensor created by ``OutputShaper.typeConversionOp``.
    """
    result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

    # One scale per channel (last dimension) or a single shared scale
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input zero point; a type with a zero point counts as one bit wider
    input_zp = 0
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name == ErrorIf.InputZeroPointNotZero:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            # The error case needs a zero point that is not zero
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1

    # Output zero point, handled the same way
    output_zp = 0
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name == ErrorIf.OutputZeroPointNotZero:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))
    random_factor = np.float32(self.rng.random(size=[nc]))
    width_ratio = np.float32((1 << out_type_width) / (1 << in_type_width))
    scale_arr = random_factor * width_ratio

    smallest_scale = 1.0 / (1 << 31)
    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, smallest_scale, (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, smallest_scale, 32767.0)

    multiplier_arr = np.zeros(shape=[nc], dtype=np.int32)
    shift_arr = np.zeros(shape=[nc], dtype=np.int32)
    for channel in range(nc):
        multiplier, shift = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[channel], scale32
        )
        multiplier_arr[channel] = multiplier
        shift_arr[channel] = shift

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=(input_zp, output_zp),
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    )

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
3252
def build_cond_if_const(self, op, then_tens, else_tens, cond):
    """Serialize a cond_if whose THEN/ELSE blocks each hold one constant.

    The supplied then/else tensors are only used for their shape (taken
    from ``then_tens``).  Each block gets an INT32 constant of that shape
    filled with random values in ``[0, 256)`` and marks it as the block
    output.

    Returns the INT32 output tensor of the cond_if operator.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    # Random payloads for the two branches
    out_shape = then_tens.shape
    then_values = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_values = np.int32(self.rng.integers(0, 256, size=out_shape))

    # Result tensor of the operator itself
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # The attribute carries the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Emit the operator first, then the two blocks it refers to
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    branches = ((then_block, then_values), (else_block, else_values))
    for block_name, block_values in branches:
        self.ser.startBasicBlock(block_name)
        # The constant lives inside its own block and is that block's output
        block_const = self.ser.addConst(out_shape, DType.INT32, block_values)
        self.ser.addOutputTensor(block_const)

    return result_tens
3288
def build_cond_if_binary(self, op, a, b, cond):
    """Serialize a cond_if whose THEN/ELSE blocks apply a binary operator.

    For FLOAT/INT32 inputs the blocks use ADD and SUB; for INT8/INT16
    inputs they use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT.  Any other
    dtype trips an assertion.

    Returns the output tensor of the cond_if operator (shape and dtype of
    ``a``).
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # The attribute carries the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Emit the operator first, then the two blocks it refers to
    operand_names = [cond_tens.name, a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name], attr)

    # Pick the operator pair according to the input data type
    if a.dtype in (DType.FLOAT, DType.INT32):
        block_ops = ((then_block, Op.ADD), (else_block, Op.SUB))
    elif a.dtype in (DType.INT8, DType.INT16):
        block_ops = (
            (then_block, Op.LOGICAL_RIGHT_SHIFT),
            (else_block, Op.LOGICAL_LEFT_SHIFT),
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    for block_name, block_op in block_ops:
        self.ser.startBasicBlock(block_name)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        block_out = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [block_out.name])

    return result_tens
3324
def build_while_loop(self, op, a, iter_val):
    """Serialize a while_loop that accumulates ``a`` ``iter_val`` times.

    The COND block outputs ``iter > 0``.  The BODY block adds ``a`` to the
    accumulator and subtracts one from the iteration counter.  The
    accumulator starts as zeros with the shape of ``a``.

    Returns the accumulator output tensor of the while_loop operator.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zero
    acc_zeros = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_zeros)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    loop_inputs = [iter_tens.name, a.name, acc.name]
    loop_outputs = [iter_out.name, a_out.name, acc_out.name]
    self.ser.addOperator(op["op"], loop_inputs, loop_outputs, attr)
    self.ser.addOutputTensor(acc_out)

    # COND block (input: iter, output: cond_tens )
    self.ser.startBasicBlock(cond_block)
    for loop_tensor in (iter_tens, a, acc):
        self.ser.addInputTensor(loop_tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
    cond_tens = self.ser.addOutput([], DType.BOOL)
    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    for loop_tensor in (iter_tens, a, acc):
        self.ser.addInputTensor(loop_tensor)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
    iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for body_output in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(body_output)

    return acc_out
3378
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
    """Build the shape/rank/dtype filters used to enumerate tests for ``op``.

    Args:
        op: operator definition; ``op["rank"]`` is a ``(min, max)`` pair and
            ``op["types"]`` lists the supported dtypes (entries may
            themselves be lists of dtypes).
        shapeFilter: requested shapes; an empty/None value means ``[None]``.
        rankFilter: requested ranks, or None for the default behaviour.
        dtypeFilter: requested dtypes, or None to allow all operator dtypes.
        testType: ``'positive'`` or ``'negative'``.
        validator: ERROR_IF validator, only called for negative tests as
            ``validator(check=False, op=op)``; its ``'param_reqs'`` entry
            may override the rank, dtype and shape filters.

    Returns:
        A dict with keys ``'shapeFilter'``, ``'rankFilter'`` and
        ``'dtypeFilter'``, or None when ``testType`` is neither
        ``'positive'`` nor ``'negative'``.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Operator dtypes filtered by the requested dtypes; a list-valued
        # entry is selected by its first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == 'positive':
        return {
            'shapeFilter': shapeFilter,
            'rankFilter': cleanRankFilter,
            'dtypeFilter': cleanDtypeFilter,
        }

    if testType == 'negative':
        error_arguments = validator(check=False, op=op)['param_reqs']

        # Use the validator's requirement where it gives one.  Identity
        # checks are used so array-valued requirements are handled too.
        required_rank = error_arguments['rank']
        required_dtype = error_arguments['dtype']
        required_shape = error_arguments['shape']

        if required_shape is None:
            # Reduce number of shapes to keep test numbers small
            required_shape = shapeFilter[:2]

        return {
            'shapeFilter': required_shape,
            'rankFilter': cleanRankFilter if required_rank is None else required_rank,
            'dtypeFilter': cleanDtypeFilter if required_dtype is None else required_dtype,
        }

    # Unknown test type: no filters are produced
    return None
3445
3446
def genOpTestList(
    self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
):
    """Enumerate the tests to generate for operator ``opName``.

    The random number generator is re-seeded from ``self.random_seed``,
    then one test is produced per combination of rank, dtype, shape and
    argument list allowed by the filters.  For positive tests, entries
    flagged by any of the operator's ``"invalid_test_validators"`` are
    removed from the result.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter, rankFilter, dtypeFilter: passed to
            ``create_filter_lists``.  The default shape filter is never
            modified here.
        testType: ``'positive'`` or ``'negative'``.

    Returns:
        A list of tuples
        ``(opName, testNameStr, dtype, error_name, shapeList, args)``.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == 'negative' and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)['error_name']
        else:
            error_name = None

        filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
        cleanRankFilter = filterDict['rankFilter']
        cleanDtypeFilter = filterDict['dtypeFilter']
        cleanShapeFilter = filterDict['shapeFilter']

        for r in cleanRankFilter:
            if opName.startswith("conv3d"):
                assert r == 5, "conv3d test must have input rank == 5"
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consists of tuples of the (str, []) string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == 'positive':
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == 'negative':
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)

                        testList.append((opName, testStr, t, error_name, shapeList, args))

    if testType == 'positive' and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement.
        # A test is dropped as soon as ANY validator flags it; previously
        # the flag was reset for every validator, so only the last one counted.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
3535
Matthew Haddone86fd342021-09-07 16:12:21 +01003536
def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
    """Build and serialize a single test for operator ``opName``.

    A serializer is created for the test, operand tensors are generated
    with ``generate_tensors``, optional quantization info is produced by
    the operator's ``"qgen"`` entry, and the operator's build function is
    invoked.  The build function receives the ERROR_IF validators and
    ``error_name`` only when the operator defines ``"error_if_validators"``,
    and the quantization info (last) only when it is not None.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the build function after printing the
            build function, tensors and arguments.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    is_concat = op["op"] == Op.CONCAT

    # Expand a single dtype into one dtype per operand
    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif is_concat:
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Build the random tensor operands and the test
    tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

    qinfo = None
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    try:
        # Positional arguments: tensors, test arguments, then the optional
        # validator pair and the optional quantization info.
        build_args = [*tens, *testArgs]
        if error_if_validators is not None:
            build_args.extend((error_if_validators, error_name))
        if qinfo is not None:
            build_args.append(qinfo)
        resultName = build_fcn(self, op, *build_args)
    except TypeError:
        print(
            "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                build_fcn, tens, testArgs
            )
        )
        raise

    if resultName is None:
        print("Invalid ERROR_IF tests created")

    # Save the serialized test
    self.serialize("test")
3613
3614
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
    """Create the operand tensors for one test of ``op``.

    Several operators get specially constrained data:

    * INT32 ADD/SUB (positive tests only): the second operand is adjusted
      so the result stays inside the int32 range.
    * ARITHMETIC_RIGHT_SHIFT: the second operand is drawn from
      ``[0, num_bits)`` for the operand's data type.
    * SELECT: the first operand is built as BOOL.
    * INTDIV (positive tests only): data is redrawn until there is no zero
      divisor and no ``-(1 << 31)`` dividend together with a ``-1`` divisor.
    * integer MUL: operands are halved until the (shifted) product fits
      in int32; ``testArgs[0]`` is the shift.
    * CONCAT: the trailing inputs become constants, as configured by
      ``self.args.num_const_inputs_concat``.

    All other operators get ``op["operands"][0]`` placeholder tensors
    followed by constant tensors for the remaining shapes.

    Args:
        op: operator definition dictionary (``"op"`` and ``"operands"``
            entries are read).
        dtypeList: dtype per operand; not modified.
        shapeList: shape per operand.
        testArgs: test arguments of the operator.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The list of created tensors.
    """
    pCount, cCount = op["operands"]
    opcode = op["op"]

    tens = []
    if opcode in (Op.ADD, Op.SUB) and dtypeList[0] == DType.INT32 and error_name is None:
        # Make sure the operation does not cause value saturation - where
        # the number wraps due to limited number of bits to store the answer
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
        add = opcode == Op.ADD
        a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
        b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
        if add:
            res_arr = np.add(a_arr, b_arr, dtype=np.int64)
        else:
            res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

        # Work out the saturation limits
        max_i32 = (1 << 31) - 1
        min_i32 = -(1 << 31)
        max_arr = np.full(shapeList[1], max_i32)
        min_arr = np.full(shapeList[1], min_i32)

        # Find how much values exceed the maximum/minimums
        sat_max_arr = np.maximum(res_arr - max_arr, 0)
        sat_min_arr = np.minimum(res_arr - min_arr, 0)

        if not add:
            # Swap saturation values and negate values as we need to perform opposite operations
            sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

        # Create new array of unsaturated values by clipping values as needed.
        # Overflow above the maximum is folded with amin, overflow below
        # the minimum with amax.
        b_unsat_arr = b_arr
        for sat_arr, reduce_fcn in ((sat_max_arr, np.amin), (sat_min_arr, np.amax)):
            if not (sat_arr != 0).any():
                continue
            # Clip values that cause saturation
            b_unsat_arr = np.subtract(b_unsat_arr, sat_arr, dtype=np.int32)
            # Reduce axes in unsaturated tensor to match original tensor
            for axis, dim in enumerate(b_arr.shape):
                if dim != b_unsat_arr.shape[axis]:
                    assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                    b_unsat_arr = reduce_fcn(b_unsat_arr, axis=axis, keepdims=True)

        tens.append(self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr))
        tens.append(self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr))
    elif opcode == Op.ARITHMETIC_RIGHT_SHIFT:
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = self.getRandTensor(shape, dtypeList[idx])
            tens.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
    elif opcode == Op.SELECT:
        # Set datatype of condition tensor to boolean.  Work on a copy so
        # the caller's dtype list is left untouched.
        dtypeList = [DType.BOOL, *dtypeList[1:]]
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
    elif opcode == Op.INTDIV and error_name is None:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.INTDIV must have 2 placeholders, 0 consts"

        # Two invalid cases for Op.INTDIV:
        # 1. divisor == 0
        # 2. dividend == -(1<<31) and divisor == -1
        while True:
            dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

            if (divisor_arr == 0).any():
                continue

            if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                continue

            break

        tens.append(
            self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
        )
        tens.append(
            self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
        )
    elif opcode == Op.MUL:
        assert (
            pCount == 2 and cCount == 0
        ), "Op.MUL must have 2 placeholders, 0 consts"

        if dtypeList[0] == DType.FLOAT:
            tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
        else:
            # Make sure multiply result in int32 range
            shift = testArgs[0]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            else:
                raise Exception("OpMul: invalid input dtype")

            low = -(2 ** (num_bits - 1))
            high = (2 ** (num_bits - 1)) - 1

            # NOTE(review): the operands are drawn once per entry of
            # shapeList and only the last draw is used.  The repeated draws
            # are kept so the random stream, and therefore the generated
            # test data, stays unchanged - confirm before simplifying.
            for _ in shapeList[:]:
                a_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[1])
                )

            while True:
                # Evaluate the product in 64 bits so it cannot wrap
                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2 ** 31)).all() and (
                    result_arr <= ((2 ** 31) - 1)
                ).all():
                    break

                # Halve both operands and try again
                a_arr = a_arr // 2
                b_arr = b_arr // 2

            tens.append(self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr))
            tens.append(self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr))
    elif opcode == Op.CONCAT:
        # Number of leading inputs that stay placeholders (at least one);
        # the rest become constants.
        count = len(shapeList) - self.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if self.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
    else:
        tens.extend(
            self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

    return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003815
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` convolution entries of TOSA_OP_LIST.

    For every kernel size listed below, a shallow copy of the matching
    template entry is registered under a name such as ``conv2d_3x3`` or
    ``conv3d_1x2x1``, with "filter" set to that kernel and "template" set
    to False.  Afterwards every entry still flagged as a template is
    removed from TOSA_OP_LIST.

    The op list is modified in place.  Because the template entries are
    deleted at the end, calling this a second time on the same op list
    raises KeyError for the missing templates.
    """
    # Kernel sizes to instantiate for the 2D and 3D convolution templates
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # (op name prefixes, kernels) pairs.  Kernels form the outer loop so
    # entries are inserted as conv2d, depthwise_conv2d, transpose_conv2d for
    # each kernel in turn, followed by all of the conv3d variants.
    expansions = (
        (("conv2d", "depthwise_conv2d", "transpose_conv2d"), KERNELS_2D),
        (("conv3d",), KERNELS_3D),
    )

    op_list = self.TOSA_OP_LIST
    for prefixes, kernels in expansions:
        for kernel in kernels:
            # e.g. [3, 1] -> "3x1", [1, 2, 1] -> "1x2x1"
            suffix = "x".join(str(dim) for dim in kernel)
            for prefix in prefixes:
                # Shallow copy: nested values stay shared with the template
                entry = op_list[prefix + "_TEMPLATE"].copy()
                entry["filter"] = kernel
                entry["template"] = False
                op_list["{}_{}".format(prefix, suffix)] = entry

    # Delete any templates after having created the dynamic ops.  The names
    # are collected first because a dict must not change size while it is
    # being iterated.  Entries without a "template" key are kept.
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
3862
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints the op table: every entry of TOSA_OP_LIST must provide
    "operands" (unpackable into exactly two values), "build_fcn"
    (unpackable into exactly three values), "types" and "op".  An entry
    without a "rank" key is given DEFAULT_RANK_RANGE.

    Raises:
        Exception: if an entry is missing one of the required fields, or
            "operands" / "build_fcn" does not have the expected length.
    """
    for name, entry in self.TOSA_OP_LIST.items():

        # Required fields.  The unpacking is the validation: the values
        # themselves are not needed here.
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            ) from err

        try:
            _build, _tensor_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(name)
            )

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003904
# Tensor operator list
#  'op': Op enum value of the operator (e.g. Op.ARGMAX)
#  'operands': tuple of (placeholder, const) operand counts
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    ArgGen function or None)
#  'types': array of datatypes to be tested
#  'template': optional, True marks an entry that createDynamicOpLists
#    copies once per kernel size (setting 'filter') and then deletes
# Datatype lists referenced by the "types" field of the op table below.
# The composite lists are built from the base ones; element order is
# significant and is kept as FLOAT / INT8 / INT16 / INT32 / BOOL where noted.

# Base lists
TYPE_FP = [DType.FLOAT]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_BOOL = [DType.BOOL]

# Integer widths followed by float (INT8, INT16, INT32, FLOAT); excludes INT4
TYPE_INT_FP = TYPE_INT + TYPE_FP

# Float plus selected integer widths
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

# Float, all three integer widths, then bool
TYPE_FIB = TYPE_FP + TYPE_INT + TYPE_BOOL

# 8- and 16-bit integers plus float (no INT32)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Datatype combinations for the convolution-style ops.  Integer cases are
# three-element lists while the float case is a bare DType.
# NOTE(review): the triples presumably mean [input, weights, accumulator]
# - confirm against the tensor generators.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# Rank range given to any op that does not specify its own "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003932
3933 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003934 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003935 "argmax": {
3936 "op": Op.ARGMAX,
3937 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003938 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003939 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3940 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01003941 "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
3942 TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
3943 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003944 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 "avg_pool2d": {
3946 "op": Op.AVG_POOL2D,
3947 "operands": (1, 0),
3948 "rank": (4, 4),
3949 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3950 "qgen": TosaQuantGen.qgUnary,
3951 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003952 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
3953 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
3954 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
3955 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
3956 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08003957 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003958 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003959 "conv2d_TEMPLATE": {
3960 "op": Op.CONV2D,
3961 "operands": (1, 2),
3962 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01003963 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003964 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003965 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003966 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003967 "template": True,
3968 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003969 # Templated operator. Filled in by createDynamicOpLists
3970 "conv3d_TEMPLATE": {
3971 "op": Op.CONV3D,
3972 "operands": (1, 2),
3973 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01003974 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07003975 "qgen": TosaQuantGen.qgConv,
3976 "types": TYPE_CONV,
3977 "template": True,
3978 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003979 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003980 "depthwise_conv2d_TEMPLATE": {
3981 "op": Op.DEPTHWISE_CONV2D,
3982 "operands": (1, 2),
3983 "filter": [1, 1],
3984 "rank": (4, 4),
3985 "build_fcn": (
3986 build_depthwise_conv2d,
3987 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01003988 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003989 ),
3990 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003991 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003992 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003993 "template": True,
3994 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003995 "fully_connected": {
3996 "op": Op.FULLY_CONNECTED,
3997 "operands": (1, 2),
3998 "rank": (2, 2),
3999 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
4000 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004001 "types": TYPE_CONV,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004002 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
4003 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 "matmul": {
4006 "op": Op.MATMUL,
4007 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07004008 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08004009 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
4010 "qgen": TosaQuantGen.qgMatmul,
4011 "types": TYPE_NARROW_INT_FP,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004012 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
4013 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 "max_pool2d": {
4016 "op": Op.MAX_POOL2D,
4017 "operands": (1, 0),
4018 "rank": (4, 4),
4019 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
4020 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004021 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
4022 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
4023 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4024 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08004025 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004026 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08004027 "transpose_conv2d_TEMPLATE": {
4028 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07004029 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004030 "rank": (4, 4),
4031 "build_fcn": (
4032 build_transpose_conv2d,
4033 TosaTensorGen.tgTransposeConv2D,
4034 TosaArgGen.agTransposeConv2D,
4035 ),
4036 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07004037 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01004038 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004039 "template": True,
4040 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004041 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08004042 "clamp": {
4043 "op": Op.CLAMP,
4044 "operands": (1, 0),
4045 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
4046 "types": TYPE_NARROW_INT_FP,
4047 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004048 "sigmoid": {
4049 "op": Op.SIGMOID,
4050 "operands": (1, 0),
4051 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
4052 "types": TYPE_FP,
4053 },
4054 "tanh": {
4055 "op": Op.TANH,
4056 "operands": (1, 0),
4057 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
4058 "types": TYPE_FP,
4059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 # Elementwise Binary Operators
4061 "add": {
4062 "op": Op.ADD,
4063 "operands": (2, 0),
4064 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4065 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004066 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4067 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004068 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004069 "arithmetic_right_shift": {
4070 "op": Op.ARITHMETIC_RIGHT_SHIFT,
4071 "operands": (2, 0),
4072 "build_fcn": (
4073 build_arithmetic_right_shift,
4074 TosaTensorGen.tgBroadcastFuzz,
4075 TosaArgGen.agArithmeticRightShift,
4076 ),
4077 "types": TYPE_INT,
4078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004079 "bitwise_and": {
4080 "op": Op.BITWISE_AND,
4081 "operands": (2, 0),
4082 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4083 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004084 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4085 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004086 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004087 "bitwise_or": {
4088 "op": Op.BITWISE_OR,
4089 "operands": (2, 0),
4090 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4091 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004092 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4093 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 "bitwise_xor": {
4096 "op": Op.BITWISE_XOR,
4097 "operands": (2, 0),
4098 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4099 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004100 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4101 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004102 },
Matthew Haddon459443c2021-08-23 16:43:13 +01004103 "intdiv": {
4104 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004105 "operands": (2, 0),
4106 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4107 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004108 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4109 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07004110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004111 "logical_and": {
4112 "op": Op.LOGICAL_AND,
4113 "operands": (2, 0),
4114 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4115 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004116 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4117 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004118 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004119 "logical_left_shift": {
4120 "op": Op.LOGICAL_LEFT_SHIFT,
4121 "operands": (2, 0),
4122 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4123 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004124 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4125 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004126 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 "logical_right_shift": {
4128 "op": Op.LOGICAL_RIGHT_SHIFT,
4129 "operands": (2, 0),
4130 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4131 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004132 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4133 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004135 "logical_or": {
4136 "op": Op.LOGICAL_OR,
4137 "operands": (2, 0),
4138 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4139 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004140 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4141 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004143 "logical_xor": {
4144 "op": Op.LOGICAL_XOR,
4145 "operands": (2, 0),
4146 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4147 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004148 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4149 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 "maximum": {
4152 "op": Op.MAXIMUM,
4153 "operands": (2, 0),
4154 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4155 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004156 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4157 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004159 "minimum": {
4160 "op": Op.MINIMUM,
4161 "operands": (2, 0),
4162 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4163 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004164 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4165 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004166 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004167 "mul": {
4168 "op": Op.MUL,
4169 "operands": (2, 0),
4170 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
4171 "types": TYPE_INT_FP,
4172 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004173 "pow": {
4174 "op": Op.POW,
4175 "operands": (2, 0),
4176 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
4177 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004178 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4179 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004181 "sub": {
4182 "op": Op.SUB,
4183 "operands": (2, 0),
4184 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
4185 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004186 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4187 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 "table": {
4190 "op": Op.TABLE,
4191 # Use the automatic generation functions to create the input array
4192 # but create the table tensor in the build function, as it may be
4193 # a different type from the input
4194 "operands": (1, 0),
4195 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004196 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08004197 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004198 # Elementwise Unary operators
4199 "abs": {
4200 "op": Op.ABS,
4201 "operands": (1, 0),
4202 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4203 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004204 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4205 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004206 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004207 "bitwise_not": {
4208 "op": Op.BITWISE_NOT,
4209 "operands": (1, 0),
4210 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4211 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004212 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4213 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004214 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004215 "ceil": {
4216 "op": Op.CEIL,
4217 "operands": (1, 0),
4218 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4219 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004220 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4221 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004223 "clz": {
4224 "op": Op.CLZ,
4225 "operands": (1, 0),
4226 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4227 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004228 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4229 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004231 "exp": {
4232 "op": Op.EXP,
4233 "operands": (1, 0),
4234 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4235 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004236 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4237 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004238 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004239 "floor": {
4240 "op": Op.FLOOR,
4241 "operands": (1, 0),
4242 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4243 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004244 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4245 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004247 "log": {
4248 "op": Op.LOG,
4249 "operands": (1, 0),
4250 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4251 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004252 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4253 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004255 "logical_not": {
4256 "op": Op.LOGICAL_NOT,
4257 "operands": (1, 0),
4258 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4259 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004260 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4261 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004263 "negate": {
4264 "op": Op.NEGATE,
4265 "operands": (1, 0),
4266 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4267 "qgen": TosaQuantGen.qgUnary,
4268 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004269 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
4270 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
4271 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004272 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004273 "reciprocal": {
4274 "op": Op.RECIPROCAL,
4275 "operands": (1, 0),
4276 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4277 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004278 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4279 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004281 "rsqrt": {
4282 "op": Op.RSQRT,
4283 "operands": (1, 0),
4284 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4285 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004286 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
4287 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004288 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004289 # Elementwise Ternary operators
4290 "select": {
4291 "op": Op.SELECT,
4292 "operands": (3, 0),
4293 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
4294 "types": TYPE_FIB,
4295 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004296 # Comparison operators
4297 "equal": {
4298 "op": Op.EQUAL,
4299 "operands": (2, 0),
4300 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4301 "types": TYPE_FI32,
4302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004303 "greater_equal": {
4304 "op": Op.GREATER_EQUAL,
4305 "operands": (2, 0),
4306 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4307 "types": TYPE_FI32,
4308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004309 "greater": {
4310 "op": Op.GREATER,
4311 "operands": (2, 0),
4312 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
4313 "types": TYPE_FI32,
4314 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004315 # Reduction operators
4316 "reduce_all": {
4317 "op": Op.REDUCE_ALL,
4318 "operands": (1, 0),
4319 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4320 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004321 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4322 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4323 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004324 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004325 "reduce_any": {
4326 "op": Op.REDUCE_ANY,
4327 "operands": (1, 0),
4328 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4329 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004330 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4331 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4332 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004333 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004334 "reduce_max": {
4335 "op": Op.REDUCE_MAX,
4336 "operands": (1, 0),
4337 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4338 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004339 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4340 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4341 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 "reduce_min": {
4344 "op": Op.REDUCE_MAX,
4345 "operands": (1, 0),
4346 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4347 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004348 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4349 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4350 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004351 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004352 "reduce_product": {
4353 "op": Op.REDUCE_PRODUCT,
4354 "operands": (1, 0),
4355 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4356 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004357 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4358 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4359 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004361 "reduce_sum": {
4362 "op": Op.REDUCE_SUM,
4363 "operands": (1, 0),
4364 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4365 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004366 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
4367 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4368 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08004369 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004370 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004371 "concat": {
4372 "op": Op.CONCAT,
4373 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01004374 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004375 "types": TYPE_FIB,
4376 },
4377 "pad": {
4378 "op": Op.PAD,
4379 "operands": (1, 0),
4380 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4381 "qgen": TosaQuantGen.qgPad,
4382 "types": TYPE_FIB,
4383 },
4384 "reshape": {
4385 "op": Op.RESHAPE,
4386 "operands": (1, 0),
4387 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4388 "types": TYPE_FIB,
4389 },
4390 "reverse": {
4391 "op": Op.REVERSE,
4392 "operands": (1, 0),
4393 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4394 "types": TYPE_FIB,
4395 },
4396 "slice": {
4397 "op": Op.SLICE,
4398 "operands": (1, 0),
4399 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4400 "types": TYPE_FIB,
4401 },
4402 "tile": {
4403 "op": Op.TILE,
4404 "operands": (1, 0),
4405 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4406 "types": TYPE_FIB,
4407 },
4408 "transpose": {
4409 "op": Op.TRANSPOSE,
4410 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004411 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004412 "build_fcn": (
4413 build_transpose,
4414 TosaTensorGen.tgBasic,
4415 TosaArgGen.agTranspose,
4416 ),
4417 "types": TYPE_FIB,
4418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004419 # Data nodes
4420 "const": {
4421 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004422 "operands": (0, 1),
4423 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004424 "types": TYPE_FIB,
4425 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004426 "identity": {
4427 "op": Op.IDENTITY,
4428 "operands": (1, 0),
4429 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4430 "types": TYPE_FIB,
4431 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004432 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004433 "gather": {
4434 "op": Op.GATHER,
4435 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4436 "operands": (1, 0),
4437 "rank": (3, 3),
4438 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4439 "types": TYPE_INT_FP,
4440 },
4441 "scatter": {
4442 "op": Op.SCATTER,
4443 # Only specify 'values_in' tensor here.
4444 #'indices' and 'input' are generated in op building stage
4445 "operands": (2, 0),
4446 "rank": (3, 3),
4447 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4448 "types": TYPE_INT_FP,
4449 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004450 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004451 "resize": {
4452 "op": Op.RESIZE,
4453 "operands": (1, 0),
4454 "rank": (4, 4),
4455 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4456 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004457 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4458 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4459 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004460 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004461 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4462 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004463 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004464 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004465 "cast": {
4466 "op": Op.CAST,
4467 "operands": (1, 0),
4468 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4469 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4470 },
4471 "rescale": {
4472 "op": Op.RESCALE,
4473 "operands": (1, 0),
Matthew Haddonc2025212021-10-08 21:21:05 +01004474 "rank": (1,4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004475 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004476 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Matthew Haddonc2025212021-10-08 21:21:05 +01004477 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
4478 TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
4479 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004480 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004481 # Custom
4482 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004483 # Control flow operators
# Two variants of cond_if, one that generates one of two constant tensors (no
4485 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4486 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004487 "cond_if_const": {
4488 "op": Op.COND_IF,
4489 "operands": (0, 2),
4490 "build_fcn": (
4491 build_cond_if_const,
4492 TosaTensorGen.tgBasic,
4493 TosaArgGen.agCondIf,
4494 ),
4495 "types": [DType.BOOL],
4496 },
4497 "cond_if_binary": {
4498 "op": Op.COND_IF,
4499 "operands": (2, 0),
4500 "build_fcn": (
4501 build_cond_if_binary,
4502 TosaTensorGen.tgBasic,
4503 TosaArgGen.agCondIf,
4504 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004505 "types": TYPE_INT_FP,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004506 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004507 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004508 "while_loop": {
4509 "op": Op.WHILE_LOOP,
4510 "operands": (0, 1),
4511 "build_fcn": (
4512 build_while_loop,
4513 TosaTensorGen.tgBasic,
4514 TosaArgGen.agWhileLoop,
4515 ),
4516 "types": [DType.INT32],
4517 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004518 }
4519
Kevin Cheng550ccc52021-03-03 11:21:43 -08004520
Eric Kunzee5e26762020-10-13 16:11:07 -07004521class OutputShaper:
4522 # Methods in this class compute the expected output shape and datatype
4523 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; the constructor sets no instance state."""
4526
4527 # These methods return arguments that can be used for
4528 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of a broadcasting binary operator.

    Args:
        ser: serializer; only its addOutput(shape, dtype) method is used.
        rng: random generator; only used (rng.choice) to pick a wrong
            output dtype when error_name is ErrorIf.WrongOutputType.
        a: first input tensor; its shape and dtype are read.
        b: second input tensor; its shape and dtype are read.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The value returned by ser.addOutput for the computed shape/dtype.
    """
    # Ranks are allowed to differ only when that mismatch is the error
    # under test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcasting (taking b's dimension where a's is 1) is only applied
    # to valid tests; for any ERROR_IF case the output mirrors a's shape.
    shape = []
    for i, dim in enumerate(a.shape):
        if dim == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(dim)

    if error_name == ErrorIf.WrongOutputType:
        # Pick any candidate dtype other than the input dtype.
        all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004550
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output of a binary operator without broadcasting.

    Both inputs must have the same rank, dtype and per-dimension sizes
    (checked with assertions). The output has a's shape and dtype.
    """
    rank = len(a.shape)
    assert rank == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        # Every dimension must match exactly: no broadcasting here.
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004562
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor of a unary operator.

    The output keeps a's shape. Its dtype is a's dtype, unless
    error_name is ErrorIf.WrongOutputType, in which case rng.choice
    picks one of the other candidate dtypes.
    """
    if error_name != ErrorIf.WrongOutputType:
        result_dtype = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        mismatching = list(set(candidates) - set([a.dtype]))
        result_dtype = rng.choice(mismatching)

    return ser.addOutput(a.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004573
@staticmethod
def selectOp(ser, cond, a, b):
    """Register the output tensor of a select operator.

    cond, a and b must share one rank, and a and b one dtype (asserted).
    Each output dimension is the largest of the three corresponding
    input dimensions; the output dtype is a's dtype.
    """
    rank = len(a.shape)
    assert rank == len(b.shape) and rank == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(cond.shape[idx], a.shape[idx], b.shape[idx]) for idx in range(rank)
    ]

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004584
@staticmethod
def binaryComparisonOp(ser, a, b):
    """Register the output tensor of a binary comparison operator.

    Inputs must have equal rank and dtype (asserted). Wherever a's
    dimension is 1 the output takes b's dimension, otherwise a's.
    The output dtype is always DType.BOOL.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast: a size-1 dimension of a is replaced by b's dimension
    out_shape = [
        b.shape[idx] if dim == 1 else dim for idx, dim in enumerate(a.shape)
    ]

    # Comparison results are boolean regardless of the input dtype
    return ser.addOutput(out_shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004600
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a reduction operator.

    The output is a copy of a's shape with the reduced axis set to 1,
    except for the axis-related ERROR_IF cases:
      * AxisSmallerZero / AxisLargerRank: the shape is left untouched.
      * ShapeOfAxisNotOne: the axis keeps its size, unless that size is
        already 1, in which case rng.integers(2, 10) replaces it.
    For WrongOutputType a dtype different from a's is chosen via rng.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # The output must NOT be 1 along the axis for this error case
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name != ErrorIf.WrongOutputType:
        result_dtype = a.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        mismatching = list(set(candidates) - set([a.dtype]))
        result_dtype = rng.choice(mismatching)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004617
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of an argmax operator.

    Starting from a copy of a's shape, the axis dimension is removed
    (skipped for AxisSmallerZero / AxisLargerRank). Further ERROR_IF
    handling:
      * ArgmaxOutputRankMismatch: randomly drop the first dimension
        (only if more than one remains) or append a dimension of 1.
      * ArgmaxOutputShapeMismatch: grow every dimension by
        rng.integers(1, 10).
      * WrongOutputType: pick a candidate dtype other than INT32.
    Otherwise the output dtype is DType.INT32.
    """
    out_shape = a.shape.copy()

    axis_is_usable = error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]
    if axis_is_usable:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        result_dtype = DType.INT32
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        mismatching = list(set(candidates) - set([DType.INT32]))
        result_dtype = rng.choice(mismatching)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004643
@staticmethod
def conv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Register the output tensor of a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.
    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input; any other input dtype raises Exception.
    """
    if len(padding) == 2:
        # transpose_conv2d supplies (H, W) padding: expand to (T, B, L, R)
        pad_h, pad_w = padding[0], padding[1]
        padding = [pad_h, pad_h, pad_w, pad_w]

    kernel_h = filter.shape[1]
    kernel_w = filter.shape[2]

    out_h = (
        ifm.shape[1]
        - kernel_h
        - (kernel_h - 1) * (dilations[0] - 1)
        + padding[0]
        + padding[1]
    ) // strides[0] + 1

    out_w = (
        ifm.shape[2]
        - kernel_w
        - (kernel_w - 1) * (dilations[1] - 1)
        + padding[2]
        + padding[3]
    ) // strides[1] + 1

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # Input dtype -> output dtype
    out_dtype_by_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtype_by_input:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtype_by_input[ifm.dtype])
Eric Kunzee5e26762020-10-13 16:11:07 -07004684
@staticmethod
def conv3dOp(ser, ifm, filter, strides, padding, dilations):
    """Register the output tensor of a 3D convolution.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.
    padding holds two entries per spatial axis (indices 0-1 for D,
    2-3 for H, 4-5 for W). The output dtype is INT32 for INT8 input,
    INT48 for INT16 input and FLOAT for FLOAT input; any other input
    dtype raises Exception.
    """
    # Spatial axes D, H, W sit at positions 1..3 of both ifm and filter
    spatial_dims = []
    for axis in range(3):
        in_dim = ifm.shape[axis + 1]
        kernel = filter.shape[axis + 1]
        spatial_dims.append(
            (
                in_dim
                - kernel
                - (kernel - 1) * (dilations[axis] - 1)
                + padding[2 * axis]
                + padding[2 * axis + 1]
            )
            // strides[axis]
            + 1
        )

    ofm_shape = [ifm.shape[0]] + spatial_dims + [filter.shape[0]]

    # Input dtype -> output dtype
    out_dtype_by_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtype_by_input:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtype_by_input[ifm.dtype])
4728
@staticmethod
def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Register the output tensor of a depthwise 2D convolution.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).
    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input; any other input dtype raises Exception.
    """
    # Spatial axes H, W: positions 1..2 of ifm, positions 0..1 of filter
    spatial_dims = []
    for axis in range(2):
        in_dim = ifm.shape[axis + 1]
        kernel = filter.shape[axis]
        spatial_dims.append(
            (
                in_dim
                - kernel
                - (kernel - 1) * (dilations[axis] - 1)
                + padding[2 * axis]
                + padding[2 * axis + 1]
            )
            // strides[axis]
            + 1
        )

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], spatial_dims[0], spatial_dims[1], out_channels]

    # Input dtype -> output dtype
    out_dtype_by_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtype_by_input:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtype_by_input[ifm.dtype])
Eric Kunzee5e26762020-10-13 16:11:07 -07004762
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Register the output tensor of a 2D pooling operator (input NHWC).

    With a non-positive stride or a negative pad value the spatial
    output size is forced to 1x1 (such a test is invalid anyway).
    ERROR_IF handling:
      * PoolingOutputShapeMismatch: add a random 1..5 to height and width.
      * WrongOutputType: pick a candidate dtype other than ifm's.
    """
    unusable_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if unusable_params:
        # Avoid dividing by a bad stride; the dimensions do not matter here
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        out_w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        offsets = [1, 2, 3, 4, 5]
        out_h = out_h + rng.choice(offsets)
        out_w = out_w + rng.choice(offsets)

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        result_dtype = ifm.dtype
    else:
        candidates = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        mismatching = list(set(candidates) - set([ifm.dtype]))
        result_dtype = rng.choice(mismatching)

    return ser.addOutput(ofm_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004789
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Register the output tensor of a fully-connected operator.

    Layouts: input is (N, IC), filter is (OC, IC), output is (N, OC).

    Args:
        ser: serializer; only its addOutput(shape, dtype) method is used.
        rng: random generator; only used (rng.choice) to pick the output
            dtype when error_name is ErrorIf.WrongOutputType.
        input: input tensor; shape[0] and dtype are read.
        filter: weight tensor; shape[0] is read.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The value returned by ser.addOutput.

    Raises:
        Exception: if the input dtype is not INT8, INT16 or FLOAT and the
            test is not a WrongInputType case.
    """
    output_shape = [input.shape[0], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists every candidate dtype except the correct
        # output dtype for that input dtype.
        if input.dtype == DType.INT8:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
        elif input.dtype == DType.INT16:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
        elif input.dtype == DType.FLOAT:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
        else:
            # Without this branch incorrect_types stayed unbound and the
            # failure surfaced as an UnboundLocalError.
            raise Exception("Unsupported input dtype: {}".format(input.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004819
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of a matmul operator.

    Layouts: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    Args:
        ser: serializer; only its addOutput(shape, dtype) method is used.
        rng: random generator; only used (rng.choice) to pick the output
            dtype when error_name is ErrorIf.WrongOutputType.
        a: first input tensor; shape[0], shape[1] and dtype are read.
        b: second input tensor; shape[2] is read.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The value returned by ser.addOutput.

    Raises:
        Exception: if a's dtype is not INT8, INT16 or FLOAT and the test
            is not a WrongInputType case.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists every candidate dtype except the correct
        # output dtype for that input dtype.
        if a.dtype == DType.INT8:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
        elif a.dtype == DType.INT16:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
        elif a.dtype == DType.FLOAT:
            incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
        else:
            # Without this branch incorrect_types stayed unbound and the
            # failure surfaced as an UnboundLocalError.
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004849
@staticmethod
def concatOp(ser, axis, *a):
    """Register the output tensor of a concat operator.

    The output copies the first input's shape and dtype; along `axis`
    its size is the sum of that dimension over all inputs.
    """
    first, others = a[0], a[1:]

    out_shape = first.shape.copy()
    out_shape[axis] = first.shape[axis] + sum(t.shape[axis] for t in others)

    return ser.addOutput(out_shape, first.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004863
@staticmethod
def padOp(ser, a, padding):
    """Register the output tensor of a pad operator.

    padding[i] supplies two values for dimension i; both are added to
    that dimension of a copy of a's shape. The dtype is a's dtype.
    """
    out_shape = a.shape.copy()

    for idx, dim in enumerate(out_shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004873
@staticmethod
def reshapeOp(ser, a, shape):
    """Register the output tensor of a reshape operator.

    The output shape is a copy of `shape` in which every -1 entry is
    replaced by (number of elements in a) // (product of the entries of
    `shape` that are not -1). The dtype is a's dtype.
    """
    out_shape = shape.copy()

    # Element count of the input tensor
    input_elements = 1
    for dim in a.shape:
        input_elements *= dim

    # Element count accounted for by the explicitly given dimensions
    known_elements = 1
    for dim in (d for d in out_shape if d != -1):
        known_elements *= dim

    # Resolve the -1 placeholders
    for idx, dim in enumerate(out_shape):
        if dim == -1:
            out_shape[idx] = input_elements // known_elements

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004894
@staticmethod
def sliceOp(ser, a, begin, size):
    """Register the output tensor of a slice operator.

    The output shape is a copy of `size` and the dtype is a's dtype;
    `begin` is not used here.
    """
    return ser.addOutput(size.copy(), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004900
@staticmethod
def tileOp(ser, a, multiples):
    """Register the output tensor of a tile operator.

    `multiples` must have one entry per dimension of a (asserted); each
    output dimension is a's dimension times its multiple. The dtype is
    a's dtype.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for idx, dim in enumerate(out_shape):
        out_shape[idx] = dim * multiples[idx]

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004911
@staticmethod
def transposeOp(ser, a, perms):
    """Register the output tensor of a transpose operator.

    `perms` must have one entry per dimension of a (asserted); output
    dimension i is a.shape[perms[i]]. The dtype is a's dtype.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    for dst, src in enumerate(perms):
        out_shape[dst] = a.shape[src]

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004921
@staticmethod
def gatherOp(ser, values, indices):
    """Register the output tensor of a gather operator.

    `values` must be rank 3 and `indices` rank 2, with matching first
    dimensions (asserted). The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]] and the dtype
    is that of `values`.
    """
    values_shape = values.shape
    indices_shape = indices.shape

    assert len(values_shape) == 3
    assert len(indices_shape) == 2
    assert values_shape[0] == indices_shape[0]

    out_shape = [values_shape[0], indices_shape[1], values_shape[2]]

    return ser.addOutput(out_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004931
@staticmethod
def scatterOp(ser, values_in, indices, input):
    """Register the output tensor of a scatter operator.

    `values_in` and `input` must be rank 3 and `indices` rank 2, with
    consistent N, W and C dimensions (asserted). The output reuses
    values_in's shape object (no copy) and dtype.
    """
    for tensor, expected_rank in ((values_in, 3), (indices, 2), (input, 3)):
        assert len(tensor.shape) == expected_rank

    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    return ser.addOutput(values_in.shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004944
@staticmethod
def tableOp(ser, input, table_dtype):
    """Register the output tensor of a table operator.

    The output has the input's shape. Only INT16 and INT8 tables are
    accepted (asserted): an INT16 table gives an INT32 output, an INT8
    table gives an INT8 output.
    """
    assert table_dtype in (DType.INT16, DType.INT8)

    if table_dtype == DType.INT16:
        result_dtype = DType.INT32
    else:
        result_dtype = DType.INT8

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004951
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor of a resize operator.

    Only input.shape, output_dims, output_dtype, rng and error_name are
    read here; the remaining parameters are accepted but unused.
    The 4-element output shape is built as follows:
      * WrongRank: [input.shape[0], output_dims[0], output_dims[0],
        input.shape[0]] (input.shape[3] is not accessed).
      * BatchMismatch: first dimension increased by rng.integers(1, 10).
      * ChannelMismatch: last dimension increased by rng.integers(1, 10).
      * otherwise: [input.shape[0], output_dims[0], output_dims[1],
        input.shape[3]].
    """
    if error_name == ErrorIf.WrongRank:
        full_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        full_dims = [
            input.shape[0] + rng.integers(1, 10),
            output_dims[0],
            output_dims[1],
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        full_dims = [
            input.shape[0],
            output_dims[0],
            output_dims[1],
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        full_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

    return serializer.addOutput(full_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004979
@staticmethod
def typeConversionOp(ser, val, out_dtype):
    """Register an output with val's shape and the requested dtype."""
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004983
@staticmethod
def transposeConv2DOp(ser, ifm, output_shape):
    """Register the output tensor of a transpose 2D convolution.

    The caller supplies the output shape; only the dtype is derived
    here: INT32 for INT8 input, INT48 for INT16 input, FLOAT for FLOAT
    input. Any other input dtype raises Exception.
    """
    # Input dtype -> output dtype
    out_dtype_by_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype not in out_dtype_by_input:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(output_shape, out_dtype_by_input[ifm.dtype])