blob: 928ac0ef2c0768fcabc12b7b9fd89a92c8ab4a78 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
48DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080049Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070050ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition"""

    def __init__(self):
        # Stateless: every generator below is a staticmethod.
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Return a random zero point appropriate for dtype.

        Args:
            testGen: test generator providing randInt(low, high)
            dtype: a DType value
            error_name: optional ErrorIf value; for Input/OutputZeroPointNotZero
                the returned zero point is guaranteed to be non-zero

        Returns:
            int zero point
        """
        # Honor an explicit non-zero zero-point request first. Previously the
        # dtype checks came first, so for INT8/UINT8 inputs a
        # *ZeroPointNotZero request could still return 0 and the ERROR_IF
        # condition would never actually be set up.
        if error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
            if dtype == DType.UINT8:
                zero_point = testGen.randInt(0, 256)
            else:
                zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point

        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        # All other dtypes use a zero zero-point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo (input_zp, output_zp); error_name selects
        which of the two zero points is forced non-zero."""
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo (input_zp, weights_zp).

        dtype_or_dtypeList is either a single dtype or an
        [input, weights, accumulator] list. error_name is accepted for
        interface uniformity with the other qg* generators but is not
        used here.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo (a_zp, b_zp); error_name currently unused."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo (input_zp); error_name currently unused."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32. scale32 selects a
        32-bit (scaleBits=31) or 16-bit (scaleBits=15) multiplier.

        Returns:
            (multiplier, shift) with multiplier <= 1 << scaleBits and
            2 <= shift <= 62
        """
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        # NOTE(review): this makes the mantissa positive for negative scales,
        # i.e. the multiplier's sign is dropped — presumably only
        # non-negative scales are passed in; confirm with callers.
        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # round() can land exactly on 1 << scaleBits; renormalize
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
158
159
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        # Stateless: every generator below is a staticmethod.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Make one random shape and use it for every input operand.

        For ErrorIf.RankMismatch tests, `shape` is regenerated with a
        different rank between operands so they deliberately disagree.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Regenerate `shape` so the *next* operand gets a mismatched rank.
            # NOTE(review): the `i != 1` guards appear intended to leave one
            # operand unchanged — confirm against the ERROR_IF checker.
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # rank 1 can only mismatch upwards
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Make a random NHWC shape shared by all input operands.

        rank must be 4 unless a WrongRank ERROR_IF test is being built.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
        # Constrict dimension size for large ranks
        if rank > 4:
            shape[4] = 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Make [values_in, input] shapes for scatter/gather-style ops.

        The second operand shares dims 0 and 2 with values_in but gets a
        random size W for dim 1.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Make one shape per operand, setting a random dimension of one
        randomly chosen operand to 1 so broadcasting is exercised."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    # rank 1 can only mismatch upwards
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Make [ifm, filter, bias] shapes for CONV2D: ifm is NHWC, filter is
        OHWI with kernel size taken from op["filter"], bias is [OC]."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Make [ifm, filter, bias] shapes for CONV3D: ifm is NDHWC, filter
        is ODHWI, bias is [OC]."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Make [ifm, filter, bias] shapes for TRANSPOSE_CONV2D (same layout
        as tgConv2D: ifm NHWC, filter OHWI, bias [OC])."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Make [ifm, filter, bias] shapes for DEPTHWISE_CONV2D: ifm is NHWC,
        filter is HWCM, bias is [C * M]."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Make [input, filter, bias] shapes for FULLY_CONNECTED:
        input (N, IC), filter (OC, IC), bias (OC)."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        # Random output-channel count within the configured shape range
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Make [a, b] shapes for MATMUL: b shares dims 0 and 2 of a so the
        inner dimensions agree, with a random output-channel count b_oc."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Make identical shapes for a variable number of CONCAT inputs
        (the declared operands plus a random number of extra tensors)."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rewrite a CONCAT shape list so the extra inputs are successive
        halvings of the first shape along `axis` instead of full-size
        copies, keeping overall tensor volume down."""
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
460
461
Eric Kunzee5e26762020-10-13 16:11:07 -0700462class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800463 """Argument generators create exhaustive or random lists of attributes for operators that take
464 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
465 tuples where the descriptive_name is appended to the test name and the arglist is expanded
466 as arguments to the operator build function."""
467
    def __init__(self):
        # Stateless: every generator in this class is a staticmethod.
        pass
470
471 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100472 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800473 """A trivial argument generator for operators that don't take any
474 non-tensor arguments"""
475 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700476
477 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100478 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800479 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700480 axes = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700481 shape = shapeList[0]
482
Matthew Haddond6ce7252021-09-29 15:35:44 +0100483 if error_name == ErrorIf.AxisSmallerZero:
484 small_axis = testGen.rng.integers(-5, 0)
485 axes.append(("axis{}".format(small_axis), [small_axis]))
486 elif error_name == ErrorIf.AxisLargerRank:
487 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
488 axes.append(("axis{}".format(large_axis), [large_axis]))
489 else:
490 for a in range(0, len(shape)):
491 axes.append(("axis{}".format(a), [a]))
492
Eric Kunzee5e26762020-10-13 16:11:07 -0700493 return axes
494
495 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100496 def agConv(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700497 arg_list = []
498
499 ifm_shape = shapeList[0]
500 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100501 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
502 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700503
Les Bell7aa69f42021-09-20 10:44:07 +0100504 # Check the rank
505 rank = 5 if opName.startswith("conv3d") else 4
506 assert len(ifm_shape) == rank
507 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700508
Les Bell7aa69f42021-09-20 10:44:07 +0100509 # kernel rank omits batch and channels
510 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700511
Les Bell7aa69f42021-09-20 10:44:07 +0100512 # Generate comprehensive argument lists
513 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
514 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
515 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
516 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
517 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
518 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700519
Les Bell7aa69f42021-09-20 10:44:07 +0100520 # add some oversize argument values
521 if max(ifm_shape) < 64:
522 bigPadding = 9
523 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
524 bigStride = 8
525 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
526 bigDilation = 7
527 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100528
529 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100530 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
531 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
532 if sparsity < 13:
533 sparsity = 1
534 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
535 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100536 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100537 for s in sorted(list(strides)):
538 for p in sorted(list(paddings)):
539 for d in sorted(list(dilations)):
540 if (n % sparsity == 0
541 # padding must not exceed the kernel size ?
542 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
543 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
544 # the padded shape must exceed the kernel size
545 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
546 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
547 # the padded shape must exceed the dilation
548 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
549 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
550 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100551 arg_list.append(
552 (
553 "st{}_pad{}_dilat{}".format(
554 "".join([str(x) for x in s]),
555 "".join([str(x) for x in p]),
556 "".join([str(x) for x in d]),
557 ),
558 [s, p, d],
559 )
560 )
561 n += 1
562
Kevin Cheng1533b852021-09-01 12:51:58 -0700563 return arg_list
564
565 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100566 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700567 arg_list = []
568
569 ifm_shape = shapeList[0]
570 filter_shape = shapeList[1]
571
572 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800573 assert len(ifm_shape) == 4
574 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700575
Les Bell7aa69f42021-09-20 10:44:07 +0100576 # Generate comprehensive argument lists
577 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
578 paddings = {x for x in itertools.product(*([p_vals] * 2))}
579 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
580 strides = {x for x in itertools.product(*([s_vals] * 2))}
581 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
582 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700583
Les Bell7aa69f42021-09-20 10:44:07 +0100584 # add some oversize argument values
585 if max(ifm_shape) < 64:
586 bigPadding = 9
587 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
588 bigStride = 8
589 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
590 bigDilation = 7
591 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700592
Les Bell7aa69f42021-09-20 10:44:07 +0100593 # There are too many parameter combinations, so generate them sparsely
594 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
595 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
596 if sparsity < 13:
597 sparsity = 1
598 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
599 sparsity += 1
600 n = 0
601 for s in sorted(list(strides)):
602 for p in sorted(list(paddings)):
603 for d in sorted(list(dilations)):
604 if n % sparsity == 0:
605 # Determine the output shape
606 oh = (
607 ifm_shape[1]
608 - filter_shape[1]
609 - (filter_shape[1] - 1) * (d[0] - 1)
610 + 2 * p[0]
611 ) // s[0] + 1
612 ow = (
613 ifm_shape[2]
614 - filter_shape[2]
615 - (filter_shape[2] - 1) * (d[1] - 1)
616 + 2 * p[1]
617 ) // s[1] + 1
618 os = [ifm_shape[0], oh, ow, filter_shape[0]]
619 arg_list.append(
620 (
621 "st{}_pad{}_dilat{}_os{}".format(
622 "".join([str(x) for x in s]),
623 "".join([str(x) for x in p]),
624 "".join([str(x) for x in d]),
625 "x".join([str(x) for x in os]),
626 ),
627 [s, p, d, os],
628 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800629 )
Les Bell7aa69f42021-09-20 10:44:07 +0100630 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700631
632 return arg_list
633
634 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100635 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700636 arg_list = []
637 rank = len(shapeList[0])
638
Les Bell7ffccce2021-07-28 15:37:02 +0100639 # Exhaustively test combinations of padding on each side of each dimension
640 # - the range of padding values is defined by pad_min and pad_max
641 # - for padding >9, the name format needs to be more distinctive
642 pad_min, pad_max = 0, 1
643 pad_values = [x for x in range(pad_min, pad_max + 1)]
644 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
645 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700646
Les Bell7ffccce2021-07-28 15:37:02 +0100647 for paddings in shape_pad_values:
648 name = "pad"
649 for r in range(rank):
650 before, after = paddings[r]
651 name = f"{name}{before}{after}"
652 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700653
654 return arg_list
655
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate ("stS_kernK_padP", [stride, pad, kernel]) combinations
        for pooling operators.

        Combinations are sampled sparsely. For the pooling-specific ERROR_IF
        tests, valid arguments are transformed into invalid ones by
        TosaErrorIfArgGen.eiPoolingErrorIf; otherwise only combinations that
        yield a consistent output shape are kept.
        """
        arg_list = []

        shape = shapeList[0]
        # Pooling input must be rank 4 unless building a WrongRank error test
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    # Calculate output height to test for error_if conditions
                    oh = (shape[1] + p[0] + p[1] + s[0] - k[0]) // s[0]
                    ow = (shape[2] + p[2] + p[3] + s[1] - k[1]) // s[1]
                    # y/x map the output size back into input coordinates;
                    # combinations where they reach the input size are
                    # rejected by the elif condition below
                    y = (oh * s[0]) - p[0] - p[1] - s[0] + k[0]
                    x = (ow * s[1]) - p[2] - p[3] - s[1] + k[1]

                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
                        # Derive invalid arguments from the valid combination;
                        # the helper may return None values (no invalid
                        # variant), in which case the combination is skipped
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                        and y < shape[1] and x < shape[2]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
727
728 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100729 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700730 arg_list = []
731
732 # Enumerate the output types here
733 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800734 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700735 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800736 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700737 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800738 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700739 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800740 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700741 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800742 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700743 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800744 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700745
746 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800747 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700748
749 return arg_list
750
751 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100752 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700753 arg_list = []
754
755 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100756 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
757 if inDtype == DType.UINT8 and dtype != DType.INT8:
758 # The only output dtype for UINT8 is INT8, skip all other combinations
759 continue
760 if inDtype != DType.INT8 and dtype == DType.UINT8:
761 # The only input dtype for UINT8 is INT8, skip all other combinations
762 continue
763
Kevin Cheng550ccc52021-03-03 11:21:43 -0800764 for scale32 in [False, True]:
765 for double_round in [False, True]:
766 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700767
768 if inDtype == DType.INT48 and scale32:
769 # Illegal condition. Must be scale32=False
770 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100771 if double_round and not scale32:
772 # Illegal condition. ERROR_IF(!scale32 && double_round)
773 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700774
Kevin Cheng550ccc52021-03-03 11:21:43 -0800775 arg_list.append(
776 (
777 "out{}_sc{}_dr{}_pc{}".format(
778 DTypeNames[dtype],
779 int(scale32),
780 int(double_round),
781 int(per_channel),
782 ),
783 [dtype, scale32, double_round, per_channel],
784 )
785 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700786
787 return arg_list
788
Kevin Chengaee1fac2020-11-11 13:54:06 -0800789 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100790 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800791 arg_list = []
792
793 if dtype is DType.INT32:
794 for p in range(testGen.args.num_rand_permutations):
795
796 shift = testGen.randInt(0, 32)
797
Kevin Cheng550ccc52021-03-03 11:21:43 -0800798 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800799 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100800 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800801
802 return arg_list
803
804 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100805 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800806 arg_list = []
807
Kevin Cheng550ccc52021-03-03 11:21:43 -0800808 arg_list.append(("roundTrue", [True]))
809 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800810
811 return arg_list
812
Eric Kunzee5e26762020-10-13 16:11:07 -0700813 # Helper function for reshape. Gets some factors of a larger number.
814 @staticmethod
815 def getFactors(val, start=1):
816 factors = []
817
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100818 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700819 if (val % i) == 0:
820 factors.append(i)
821
822 return factors
823
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random new shapes for RESHAPE.

        Each candidate shape is assembled from random factors of the input's
        total element count (so the element count is preserved), optionally
        replaces one dimension with -1, and is rejected if it duplicates a
        previously generated shape.
        """
        arg_list = []

        origShape = shapeList[0]

        # Element count that every generated shape must preserve
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # The final dimension absorbs whatever element count is left
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
879
Eric Kunzee5e26762020-10-13 16:11:07 -0700880 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100881 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700882 arg_list = []
883
884 ifm_shape = shapeList[0]
885
Jeremy Johnsona6185572021-06-21 15:55:35 +0100886 # Get all permutations
887 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700888
Jeremy Johnsona6185572021-06-21 15:55:35 +0100889 # Limit to possible permutations from shape dimension or argument setting
890 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700891
Jeremy Johnsona6185572021-06-21 15:55:35 +0100892 # Get random permutation generator that uses all permutations
893 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700894
Jeremy Johnsona6185572021-06-21 15:55:35 +0100895 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700896 arg_list = [
897 ("perm{}".format(p), [random_permutations[p].tolist()])
898 for p in range(limit)
899 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700900 return arg_list
901
902 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100903 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700904 arg_list = []
905
906 ifm_shape = shapeList[0]
907 rank = len(ifm_shape)
908
909 for p in range(testGen.args.num_rand_permutations):
910 begin = []
911 size = []
912
Kevin Cheng550ccc52021-03-03 11:21:43 -0800913 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700914
915 for i in range(rank):
916 if ifm_shape[i] > 1:
917 begin.append(testGen.randInt(0, ifm_shape[i]))
918 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
919
920 # Invalid slice size?
921 if size[i] == 0:
922 valid = False
923 else:
924 begin.append(0)
925 size.append(1)
926
927 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800928 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700929 return arg_list
930
931 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100932 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700933 arg_list = []
934
935 ifm_shape = shapeList[0]
936 rank = len(ifm_shape)
937
938 for p in range(testGen.args.num_rand_permutations):
939
940 # Pick a few random, but small multiple values
941 # because otherwise this has a tendency to generate
942 # enormous tensors
943 multiples = []
944 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +0100945 if ifm_shape[i] > 1000:
946 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
947 multiples.append(1)
948 elif max(ifm_shape) > 1000:
949 multiples.append(2)
950 else:
951 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -0800952 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700953
954 return arg_list
955
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument combinations for RESIZE.

        For each resize mode and each legal output dtype, randomly picks
        legal output dimensions, then derives stride and offset from the
        input/output centre alignment - floating-point values when the
        output dtype is FLOAT, fixed-point (scaled by 1 << shift) otherwise.
        When error_name is set, the computed arguments are perturbed by
        TosaErrorIfArgGen.eiResizeErrorIf to trigger that ERROR_IF check.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Grow output dims until the downscale ratio is < 16
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Align input and output centres to derive stride/offset
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT output: use the fp stride/offset, zero the
                        # fixed-point fields
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        # Integer output: quantise stride/offset by 1 << shift
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        # Reduce shift until stride and offset fit within the
                        # +/-(16 << shift) limits
                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list
1126
Matthew Haddon1c00b712021-10-01 15:51:03 +01001127 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001128 # CondIf generates the condition values here.
1129 # Convert to tensors in the build function, along with the
1130 # then and else blocks
1131 arg_list = []
1132
1133 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001134 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001135
1136 return arg_list
1137
Matthew Haddon1c00b712021-10-01 15:51:03 +01001138 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001139 # While loop: 0 iterations, 1, more than 1
1140 arg_list = []
1141
1142 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001143 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001144
1145 return arg_list
1146
Matthew Haddone86fd342021-09-07 16:12:21 +01001147class TosaErrorIfArgGen:
1148
1149 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001150 def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
Matthew Haddone86fd342021-09-07 16:12:21 +01001151
1152 if outputDType == DType.FLOAT:
1153 if error_name == ErrorIf.StrideSmallerEqualZero:
1154 stride_fp = testGen.rng.random(size=[2]) - 2
1155 elif error_name == ErrorIf.ShiftNotZero:
1156 shift = testGen.rng.integers(1, 5)
1157 elif error_name == ErrorIf.StrideLargerDimension:
1158 shape = shapeList[0]
1159 transform_height = testGen.rng.choice([False, True])
1160 if transform_height:
1161 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
1162 else:
1163 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
1164 else:
1165 if error_name == ErrorIf.StrideSmallerEqualZero:
1166 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
1167 elif error_name == ErrorIf.ShiftSmallerOne:
1168 shift = testGen.rng.integers(-3, 1)
1169 if shift <= 0:
1170 stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1171 offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1172 else:
1173 stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1174 offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1175 elif error_name == ErrorIf.ShiftLargerEleven:
1176 shift = np.int16(testGen.rng.integers(12, 15))
1177 elif error_name == ErrorIf.StrideLargerDimension:
1178 shape = shapeList[0]
1179 transform_height = testGen.rng.choice([False, True])
1180 if transform_height:
1181 stride[0] = shape[1] + testGen.rng.integers(1, 10)
1182 else:
1183 stride[1] = shape[2] + testGen.rng.integers(1, 10)
1184 elif error_name == ErrorIf.StrideLargerEqualMax:
1185 stride = [(16 << shift) + 1, (16 << shift) + 1]
1186 elif error_name == ErrorIf.OffsetLargerEqualMax:
1187 offset = [(16 << shift) + 1, (16 << shift) + 1]
1188 elif error_name == ErrorIf.OffsetSmallerEqualMin:
1189 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
1190
Matthew Haddon1c00b712021-10-01 15:51:03 +01001191
Matthew Haddon848efb42021-09-09 12:30:53 +01001192 if error_name == ErrorIf.WrongOutputType:
1193 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
1194 incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1195 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
1196 incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
1197 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
1198 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
1199 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
1200 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
1201 elif dtype == DType.FLOAT:
1202 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
1203 outputDType = testGen.rng.choice(a=incorrect_types)
Matthew Haddone86fd342021-09-07 16:12:21 +01001204
Matthew Haddon848efb42021-09-09 12:30:53 +01001205 return shift, stride, stride_fp, offset, offset_fp, outputDType
1206
1207 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001208 def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
1209 if (error_name == ErrorIf.StrideSmallerOne
1210 # padding must not exceed the kernel size
1211 and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
1212 wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
1213 return wrongStride, pad, kernel
1214 elif error_name == ErrorIf.PadSmallerZero:
1215 wrongPad = (testGen.rng.choice([-1, -2, -3]),
1216 testGen.rng.choice([-1, -2, -3]),
1217 testGen.rng.choice([-1, -2, -3]),
1218 testGen.rng.choice([-1, -2, -3]))
1219 return stride, wrongPad, kernel
1220 elif error_name == ErrorIf.KernelSmallerOne:
1221 wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
1222 return stride, pad, wrongKernel
1223 elif error_name == ErrorIf.PadLargerEqualKernel:
1224 wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
1225 testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
1226 testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
1227 testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
1228 return stride, wrongPad, kernel
1229 else:
1230 return None, None, None
1231
1232
1233 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001234 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
1235 # Mess up input/output tensors for ERROR_IF checks
1236 if error_name == "WrongInputList":
1237 add_input = testGen.rng.choice([True, False])
1238 if add_input:
1239 input_list.append('eiDummyInput')
1240 else:
1241 input_list = input_list[:-1]
1242 if error_name == "WrongOutputList":
1243 add_output = testGen.rng.choice([True, False])
1244 if add_output:
1245 output_list.append('eiDummyOutput')
1246 else:
1247 output_list = []
1248 return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001249
1250class TosaErrorValidator:
1251
Matthew Haddon848efb42021-09-09 12:30:53 +01001252 @staticmethod
1253 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1254 # Check ERROR_IF statements
1255
1256 for val_fcn in validator_fcns:
1257 val_result = val_fcn(True, **kwargs)
1258
1259 validator_name = val_result['error_name']
1260 error_result = val_result['error_result']
1261 error_reason = val_result['error_reason']
1262
1263 if error_result:
1264 if error_name == validator_name:
1265 serializer.setExpectedReturnCode(2, error_reason)
1266 else:
1267 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1268 return None # Return None to delete test if wrong ERROR_IF is hit
1269 else:
1270 if error_name == validator_name:
1271 print(f"No ERROR_IF hit for {error_name}")
1272 return None
1273
1274 @staticmethod
1275 def evWrongInputType(check=False, **kwargs):
1276 all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1277
1278 # Find the unsupported input data types
1279 assert 'op' in kwargs
1280 op = kwargs['op']
1281 input_dtypes = op['types']
1282 wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))
1283
1284 error_name = ErrorIf.WrongInputType
1285 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1286 error_result = False
1287 error_reason = "Input data type not supported for this operator"
1288
1289 if check:
1290 input_dtype = kwargs['input_dtype']
1291 if input_dtype not in input_dtypes:
1292 error_result = True
1293
1294 info_dict = {
1295 "error_name": error_name,
1296 "error_result": error_result,
1297 "error_reason": error_reason,
1298 "param_reqs": param_reqs
1299 }
1300 return info_dict
1301
1302 @staticmethod
1303 def evWrongOutputType(check=False, **kwargs):
1304 error_name = ErrorIf.WrongOutputType
1305 param_reqs = {"rank": None, "dtype": None, "shape": None}
1306 error_result = False
1307 error_reason = "Output data type not supported for this configuration of operator"
1308
1309 if check:
1310 input_dtype = kwargs['input_dtype']
1311 output_dtype = kwargs['output_dtype']
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001312 op = kwargs['op']
Matthew Haddon848efb42021-09-09 12:30:53 +01001313
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001314 if op['op'] == Op.RESIZE:
1315 mode = kwargs['mode']
1316 if (
1317 (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
1318 (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
1319 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
1320 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
1321 (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
1322 ):
1323 error_result = True
1324 else:
1325 if output_dtype != input_dtype:
1326 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001327
1328 info_dict = {
1329 "error_name": error_name,
1330 "error_result": error_result,
1331 "error_reason": error_reason,
1332 "param_reqs": param_reqs
1333 }
1334 return info_dict
1335
1336 @staticmethod
1337 def evWrongRank(check=False, **kwargs):
1338 all_ranks = (1, 2, 3, 4, 5)
1339
1340 # Make a list of incorrect ranks
1341 assert 'op' in kwargs
1342 op = kwargs['op']
1343 rmin, rmax = op['rank']
1344 rank_range = range(rmin, rmax + 1)
1345 incorrect_ranks = list(set(all_ranks) - set(rank_range))
1346 # Set minimum incorrect rank to 3 to avoid index error
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001347 if op['op'] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001348 incorrect_ranks = [3, 5]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001349 elif op['op'] in [Op.AVG_POOL2D, Op.MAX_POOL2D]:
1350 incorrect_ranks = [5]
Matthew Haddon848efb42021-09-09 12:30:53 +01001351
1352 error_name = ErrorIf.WrongRank
1353 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1354 error_result = False
1355 error_reason = "Rank not supported for this operator"
1356
1357 if check:
1358 input_shape = kwargs['input_shape']
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001359 if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
Matthew Haddon848efb42021-09-09 12:30:53 +01001360 error_result = True
1361
1362 info_dict = {
1363 "error_name": error_name,
1364 "error_result": error_result,
1365 "error_reason": error_reason,
1366 "param_reqs": param_reqs
1367 }
1368 return info_dict
1369
1370 @staticmethod
1371 def evWrongInputList(check=False, **kwargs):
1372 error_name = ErrorIf.WrongInputList
1373 param_reqs = {"rank": None, "dtype": None, "shape": None}
1374 error_result = False
1375 error_reason = "Op input list does not match expected input"
1376
1377 if check:
1378 op = kwargs['op']
1379 input_list = kwargs['input_list']
1380 num_operands = kwargs['num_operands']
1381 if len(input_list) != num_operands:
1382 error_result = True
1383
1384 info_dict = {
1385 "error_name": error_name,
1386 "error_result": error_result,
1387 "error_reason": error_reason,
1388 "param_reqs": param_reqs
1389 }
1390 return info_dict
1391
1392 @staticmethod
1393 def evWrongOutputList(check=False, **kwargs):
1394 error_name = ErrorIf.WrongOutputList
1395 param_reqs = {"rank": None, "dtype": None, "shape": None}
1396 error_result = False
1397 error_reason = "Op output list does not match expected output"
1398
1399 if check:
1400 output_list = kwargs['output_list']
1401 # Note this will be incorrect if an operator returns more than one output
1402 if len(output_list) != 1:
1403 error_result = True
1404
1405 info_dict = {
1406 "error_name": error_name,
1407 "error_result": error_result,
1408 "error_reason": error_reason,
1409 "param_reqs": param_reqs
1410 }
1411 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001412
1413 @staticmethod
1414 def evMaxDimExceeded(check=False, **kwargs):
1415 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001416 param_reqs = {
1417 "rank": [4,4],
1418 "dtype": [DType.INT8],
1419 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1420 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001421 error_result = False
1422 error_reason = "At least one maximum dimension is larger than 16384"
1423
1424 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001425 input_shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001426 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1427 if ((input_shape[1] > 16384) or
1428 (input_shape[2] > 16384) or
1429 (output_shape[0] > 16384) or
1430 (output_shape[1] > 16384)):
1431 error_result = True
1432
1433 info_dict = {
1434 "error_name": error_name,
1435 "error_result": error_result,
1436 "error_reason": error_reason,
1437 "param_reqs": param_reqs
1438 }
1439 return info_dict
1440
1441 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001442 def evBatchMismatch(check=False, **kwargs):
1443 error_name = ErrorIf.BatchMismatch
1444 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1445 error_result = False
1446 error_reason = "Input batch size not equal to output batch size"
1447
1448 assert 'op' in kwargs
1449 op = kwargs['op']
1450 rmin, rmax = op['rank']
1451 rank_range = range(rmin, rmax + 1)
1452
1453 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001454 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001455 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1456
1457 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1458 error_result = True
1459
1460 info_dict = {
1461 "error_name": error_name,
1462 "error_result": error_result,
1463 "error_reason": error_reason,
1464 "param_reqs": param_reqs
1465 }
1466 return info_dict
1467
1468 @staticmethod
1469 def evChannelMismatch(check=False, **kwargs):
1470 error_name = ErrorIf.ChannelMismatch
1471 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1472 error_result = False
1473 error_reason = "Input channel size not equal to output channel size"
1474
1475 assert 'op' in kwargs
1476 op = kwargs['op']
1477 rmin, rmax = op['rank']
1478 rank_range = range(rmin, rmax + 1)
1479
1480 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001481 input_shape = kwargs['input_shape']
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001482 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1483 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1484 error_result = True
1485
1486 info_dict = {
1487 "error_name": error_name,
1488 "error_result": error_result,
1489 "error_reason": error_reason,
1490 "param_reqs": param_reqs
1491 }
1492 return info_dict
1493
    @staticmethod
    def evStrideSmallerEqualZero(check=False, **kwargs):
        """ERROR_IF descriptor: RESIZE stride values must be strictly positive.

        Which stride kwarg is consulted ('stride' vs 'stride_fp') depends on
        the input/output dtype combination; the mixed-dtype branches exist to
        keep the negative-type tests themselves from crashing here.
        """
        error_name = ErrorIf.StrideSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Stride value smaller than or equal zero"

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            # Branch order matters: the mixed int/float cases must be handled
            # before the plain float-output case.
            if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
                stride = kwargs['stride'] # Work around wrong input/output type tests
            elif output_dtype == DType.FLOAT:
                stride = kwargs['stride_fp']
            elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                stride = kwargs['stride_fp'] # Work around wrong input/output type tests
            else:
                stride = kwargs['stride']

            if min(stride) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1523
1524 @staticmethod
1525 def evStrideLargerEqualMax(check=False, **kwargs):
1526 error_name = ErrorIf.StrideLargerEqualMax
1527 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1528 error_result = False
1529 error_reason = "Stride value larger than or equal to maximum value"
1530
1531 if check:
1532 shift = kwargs['shift']
1533 input_dtype = kwargs['input_dtype']
1534 stride = kwargs['stride']
1535 if input_dtype in [DType.INT8, DType.INT16]:
1536 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1537 error_result = True
1538 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1539 error_result = True
1540
1541 info_dict = {
1542 "error_name": error_name,
1543 "error_result": error_result,
1544 "error_reason": error_reason,
1545 "param_reqs": param_reqs
1546 }
1547 return info_dict
1548
1549
1550 @staticmethod
1551 def evStrideLargerDimension(check=False, **kwargs):
1552 error_name = ErrorIf.StrideLargerDimension
1553 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1554 error_result = False
1555 error_reason = "Stride value larger than or equal to H/W dimension"
1556
1557 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001558 shape = kwargs['input_shape']
Matthew Haddone86fd342021-09-07 16:12:21 +01001559 input_dtype = kwargs['input_dtype']
1560 stride = kwargs['stride_fp']
1561
1562 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1563 error_result = True
1564
1565 info_dict = {
1566 "error_name": error_name,
1567 "error_result": error_result,
1568 "error_reason": error_reason,
1569 "param_reqs": param_reqs
1570 }
1571 return info_dict
1572
1573
1574 @staticmethod
1575 def evOffsetSmallerEqualMin(check=False, **kwargs):
1576 error_name = ErrorIf.OffsetSmallerEqualMin
1577 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1578 error_result = False
1579 error_reason = "Offset value smaller than or equal to minimum value"
1580
1581 if check:
1582 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001583 output_dtype = kwargs['output_dtype']
1584 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001585 offset = kwargs['offset_fp']
1586 else:
1587 offset = kwargs['offset']
1588
1589 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1590 error_result = True
1591 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1592 error_result = True
1593
1594 info_dict = {
1595 "error_name": error_name,
1596 "error_result": error_result,
1597 "error_reason": error_reason,
1598 "param_reqs": param_reqs
1599 }
1600 return info_dict
1601
1602 @staticmethod
1603 def evOffsetLargerEqualMax(check=False, **kwargs):
1604 error_name = ErrorIf.OffsetLargerEqualMax
1605 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1606 error_result = False
1607 error_reason = "Offset value larger than or equal to maximum value"
1608
1609 if check:
1610 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001611 output_dtype = kwargs['output_dtype']
1612 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001613 offset = kwargs['offset_fp']
1614 else:
1615 offset = kwargs['offset']
1616
1617 if shift >= 0:
1618 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1619 error_result = True
1620
1621 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1622 error_result = True
1623 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1624 error_result = True
1625
1626 info_dict = {
1627 "error_name": error_name,
1628 "error_result": error_result,
1629 "error_reason": error_reason,
1630 "param_reqs": param_reqs
1631 }
1632 return info_dict
1633
1634 @staticmethod
1635 def evShiftNotZero(check=False, **kwargs):
1636 error_name = ErrorIf.ShiftNotZero
1637 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1638 error_result = False
1639 error_reason = "Shift value must be zero for float input"
1640
1641 if check:
1642 shift = kwargs['shift']
1643 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001644 output_dtype = kwargs['output_dtype']
1645 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001646 error_result = True
1647
1648 info_dict = {
1649 "error_name": error_name,
1650 "error_result": error_result,
1651 "error_reason": error_reason,
1652 "param_reqs": param_reqs
1653 }
1654 return info_dict
1655
1656
1657 @staticmethod
1658 def evShiftSmallerOne(check=False, **kwargs):
1659 error_name = ErrorIf.ShiftSmallerOne
1660 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1661 error_result = False
1662 error_reason = "Shift value smaller than one"
1663
1664 if check:
1665 shift = kwargs['shift']
1666 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001667 output_dtype = kwargs['output_dtype']
1668 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001669 error_result = True
1670
1671 info_dict = {
1672 "error_name": error_name,
1673 "error_result": error_result,
1674 "error_reason": error_reason,
1675 "param_reqs": param_reqs
1676 }
1677 return info_dict
1678
1679 @staticmethod
1680 def evShiftLargerEleven(check=False, **kwargs):
1681 error_name = ErrorIf.ShiftLargerEleven
1682 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1683 error_result = False
1684 error_reason = "Shift value larger than eleven"
1685
1686 if check:
1687 shift = kwargs['shift']
1688 if shift > 11:
1689 error_result = True
1690
1691 info_dict = {
1692 "error_name": error_name,
1693 "error_result": error_result,
1694 "error_reason": error_reason,
1695 "param_reqs": param_reqs
1696 }
1697 return info_dict
1698
1699
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001700 @staticmethod
1701 def evRankMismatch(check=False, **kwargs):
1702 error_name = ErrorIf.RankMismatch
1703 param_reqs = {"rank": None, "dtype": None, "shape": None}
1704 error_result = False
1705 error_reason = "Input Rank does not match output rank"
1706
1707 if check:
1708 input1_shape = kwargs['input1'].shape
1709 input2_shape = kwargs['input2'].shape
1710 output_shape = kwargs['result_tensor'].shape
1711 if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
1712 error_result = True
1713
1714 info_dict = {
1715 "error_name": error_name,
1716 "error_result": error_result,
1717 "error_reason": error_reason,
1718 "param_reqs": param_reqs
1719 }
1720 return info_dict
1721
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001722 @staticmethod
1723 def evInputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001724 op = kwargs['op']
1725 inputDtypes = op['types'].copy()
1726 if DType.INT8 in inputDtypes:
1727 inputDtypes.remove(DType.INT8)
1728 if DType.UINT8 in inputDtypes:
1729 inputDtypes.remove(DType.UINT8)
1730
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001731 error_name = ErrorIf.InputZeroPointNotZero
1732 param_reqs = {
1733 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001734 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001735 "shape": None
1736 }
1737 error_result = False
1738 error_reason = "Input DType not INT8 and zero point not 0"
1739
1740 if check:
1741 input_dtype = kwargs['input_dtype']
1742 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1743 qinfo = kwargs['qinfo'].ints
1744 input_zero_point = qinfo[0][1]
1745 if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
1746 error_result = True
1747
1748 info_dict = {
1749 "error_name": error_name,
1750 "error_result": error_result,
1751 "error_reason": error_reason,
1752 "param_reqs": param_reqs
1753 }
1754 return info_dict
1755
1756
1757 @staticmethod
1758 def evOutputZeroPointNotZero(check=False, **kwargs):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001759 op = kwargs['op']
1760 inputDtypes = op['types'].copy()
1761 if DType.INT8 in inputDtypes:
1762 inputDtypes.remove(DType.INT8)
1763 if DType.UINT8 in inputDtypes:
1764 inputDtypes.remove(DType.UINT8)
1765
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001766 error_name = ErrorIf.OutputZeroPointNotZero
1767 param_reqs = {
1768 "rank": None,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001769 "dtype": inputDtypes,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001770 "shape": None
1771 }
1772 error_result = False
1773 error_reason = "Output DType not INT8 and zero point not 0"
1774
1775 if check:
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001776 input_dtype = kwargs['input_dtype']
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001777 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1778 qinfo = kwargs['qinfo'].ints
1779 output_zero_point = qinfo[1][1]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001780 if input_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001781 error_result = True
1782
1783 info_dict = {
1784 "error_name": error_name,
1785 "error_result": error_result,
1786 "error_reason": error_reason,
1787 "param_reqs": param_reqs
1788 }
1789 return info_dict
1790
Matthew Haddond6ce7252021-09-29 15:35:44 +01001791 @staticmethod
1792 def evAxisSmallerZero(check=False, **kwargs):
1793 error_name = ErrorIf.AxisSmallerZero
1794 param_reqs = {"rank": None, "dtype": None, "shape": None}
1795 error_result = False
1796 error_reason = "Axis smaller than zero"
1797
1798 if check:
1799 axis = kwargs['axis']
1800 if axis < 0:
1801 error_result = True
1802
1803 info_dict = {
1804 "error_name": error_name,
1805 "error_result": error_result,
1806 "error_reason": error_reason,
1807 "param_reqs": param_reqs
1808 }
1809 return info_dict
1810
1811
1812 @staticmethod
1813 def evAxisLargerRank(check=False, **kwargs):
1814 error_name = ErrorIf.AxisLargerRank
1815 param_reqs = {"rank": None, "dtype": None, "shape": None}
1816 error_result = False
1817 error_reason = "Axis larger than rank"
1818
1819 if check:
1820 axis = kwargs['axis']
1821 shape = kwargs['input_shape']
1822 if axis > len(shape):
1823 error_result = True
1824
1825 info_dict = {
1826 "error_name": error_name,
1827 "error_result": error_result,
1828 "error_reason": error_reason,
1829 "param_reqs": param_reqs
1830 }
1831 return info_dict
1832
1833
1834 @staticmethod
1835 def evShapeOfAxisNotOne(check=False, **kwargs):
1836 error_name = ErrorIf.ShapeOfAxisNotOne
1837 param_reqs = {"rank": None, "dtype": None, "shape": None}
1838 error_result = False
1839 error_reason = "shape[axis] is not equal to 1"
1840
1841 if check:
1842 axis = kwargs['axis']
1843 shape = kwargs['output_shape']
1844 if (0 <= axis < len(shape)) and shape[axis] != 1:
1845 error_result = True
1846
1847 info_dict = {
1848 "error_name": error_name,
1849 "error_result": error_result,
1850 "error_reason": error_reason,
1851 "param_reqs": param_reqs
1852 }
1853 return info_dict
1854
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001855
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001856 @staticmethod
1857 def evPadSmallerZero(check=False, **kwargs):
1858 error_name = ErrorIf.PadSmallerZero
1859 param_reqs = {"rank": None, "dtype": None, "shape": None}
1860 error_result = False
1861 error_reason = "At least one pad is smaller than zero"
1862
1863 if check:
1864 pad = kwargs['pad']
1865 if min(pad) < 0:
1866 error_result = True
1867
1868 info_dict = {
1869 "error_name": error_name,
1870 "error_result": error_result,
1871 "error_reason": error_reason,
1872 "param_reqs": param_reqs
1873 }
1874 return info_dict
1875
1876
1877 @staticmethod
1878 def evPadLargerEqualKernel(check=False, **kwargs):
1879 error_name = ErrorIf.PadLargerEqualKernel
1880 param_reqs = {"rank": None, "dtype": None, "shape": None}
1881 error_result = False
1882 error_reason = "At least one pad is larger than kernel dimension"
1883
1884 if check:
1885 pad = kwargs['pad']
1886 kernel = kwargs['kernel']
1887 if min(pad) > 0 and min(kernel) > 1:
1888 if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
1889 error_result = True
1890
1891 info_dict = {
1892 "error_name": error_name,
1893 "error_result": error_result,
1894 "error_reason": error_reason,
1895 "param_reqs": param_reqs
1896 }
1897 return info_dict
1898
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """ERROR_IF descriptor: pooling output H/W must match the computed size.

        Only fires when the pooling parameters themselves are valid, so the
        dedicated kernel/stride/pad checks report their own failures.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between output shape provided and expected output shape"

        if check:
            pad = kwargs['pad']
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs['kernel']
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs['input_shape']
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs['output_shape']
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs['stride']
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
                x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x

            # ensure parameters are valid
            params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
                and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))

            # params_valid implies both strides >= 1, so the short-circuit
            # below guarantees y_correct/x_correct exist before being read.
            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
1941
1942
1943 @staticmethod
1944 def evKernelSmallerOne(check=False, **kwargs):
1945 error_name = ErrorIf.KernelSmallerOne
1946 param_reqs = {"rank": None, "dtype": None, "shape": None}
1947 error_result = False
1948 error_reason = "At least one kernel dimension is smaller than zero"
1949
1950 if check:
1951 kernel = kwargs['kernel']
1952 if min(kernel) < 1:
1953 error_result = True
1954
1955 info_dict = {
1956 "error_name": error_name,
1957 "error_result": error_result,
1958 "error_reason": error_reason,
1959 "param_reqs": param_reqs
1960 }
1961 return info_dict
1962
1963 @staticmethod
1964 def evStrideSmallerOne(check=False, **kwargs):
1965 error_name = ErrorIf.StrideSmallerOne
1966 param_reqs = {"rank": None, "dtype": None, "shape": None}
1967 error_result = False
1968 error_reason = "At least one stride dimension is smaller than zero"
1969
1970 if check:
1971 stride = kwargs['stride']
1972 if min(stride) < 1:
1973 error_result = True
1974
1975 info_dict = {
1976 "error_name": error_name,
1977 "error_result": error_result,
1978 "error_reason": error_reason,
1979 "param_reqs": param_reqs
1980 }
1981 return info_dict
1982
1983
1984
Matthew Haddonb724efc2021-08-25 16:40:29 +01001985class TosaInvalidValidator:
1986
1987 @staticmethod
1988 def ivWrongDataTypeOrModeResize(**kwargs):
1989 input_dtype = kwargs["input_dtype"]
1990 args = kwargs["args"]
1991 mode = args[0]
1992 stride = args[1]
1993 stride_fp = args[4]
1994 output_dtype = args[8]
1995
1996 if mode == ResizeMode.BILINEAR:
1997 # Invalid output data type / Invalid input datatype
1998 return (
1999 not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
2000 not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
2001 not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or
2002 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
2003 )
2004 elif mode == ResizeMode.NEAREST:
2005 # Invalid output data type / Invalid input datatype
2006 return (
2007 (input_dtype != output_dtype) or
2008 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
2009 )
2010 else:
2011 # Invalid resize mode
2012 return True
2013
2014 @staticmethod
2015 def ivBadStride(**kwargs):
2016 input_dtype = kwargs["input_dtype"]
2017 args = kwargs["args"]
2018 stride_x = args[1][0]
2019 stride_y = args[1][1]
2020 stride_fp_x = args[4][0]
2021 stride_fp_y = args[4][1]
2022
2023 if input_dtype == DType.FLOAT:
2024 if stride_fp_x <= 0 or stride_fp_y <= 0:
2025 # Negative or zero stride
2026 return True
2027 else:
2028 if stride_x <= 0 or stride_y <= 0:
2029 # Negative or zero stride
2030 return True
2031 return False
2032
2033
Matthew Haddonb724efc2021-08-25 16:40:29 +01002034 @staticmethod
2035 def ivHeightWidthSmallerZero(**kwargs):
2036 opName = kwargs['opName']
2037
2038 inputShapes = kwargs['shapeList']
2039 input = inputShapes[0]
2040 if not opName.endswith("pool2d"):
2041 filter = inputShapes[1]
2042
2043 args = kwargs['args']
2044 strides = args[0]
2045 padding = args[1]
2046 dilations = args[2]
2047 if opName.endswith("pool2d"):
2048 kernel = args[2]
2049
2050 if opName.startswith('conv2d'):
2051 h = (
2052 input[1]
2053 - filter[1]
2054 - (filter[1] - 1) * (dilations[0] - 1)
2055 + padding[0]
2056 + padding[1]
2057 ) // strides[0] + 1
2058
2059 w = (
2060 input[2]
2061 - filter[2]
2062 - (filter[2] - 1) * (dilations[1] - 1)
2063 + padding[2]
2064 + padding[3]
2065 ) // strides[1] + 1
2066 elif opName.startswith("depthwise_conv2d"):
2067 h = (
2068 input[1]
2069 - filter[0]
2070 - (filter[0] - 1) * (dilations[0] - 1)
2071 + padding[0]
2072 + padding[1]
2073 ) // strides[0] + 1
2074
2075 w = (
2076 input[2]
2077 - filter[1]
2078 - (filter[1] - 1) * (dilations[1] - 1)
2079 + padding[2]
2080 + padding[3]
2081 ) // strides[1] + 1
2082 elif opName.endswith("pool2d"):
2083 h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
2084 w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
2085 else:
2086 assert False, "Unrecognized Op"
2087
2088 if h <= 0 or w <= 0:
2089 # Invalid parameter combination
2090 return True
2091 return False
2092
2093 @staticmethod
2094 def ivNonPositiveOutputShape(**kwargs):
2095 args = kwargs['args']
2096 output_shape = args[3]
2097 if output_shape[1] <= 0 or output_shape[2] <= 0:
2098 # Negative output shape
2099 return True
2100 return False
2101
2102
Kevin Cheng550ccc52021-03-03 11:21:43 -08002103
Eric Kunzee5e26762020-10-13 16:11:07 -07002104class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002105 # Maximum rank of tensor supported by test generator.
2106 TOSA_TENSOR_MAX_RANK = 6
2107
Eric Kunzee5e26762020-10-13 16:11:07 -07002108 def __init__(self, args):
2109 self.args = args
2110 self.basePath = args.output_dir
2111 self.random_seed = args.random_seed
2112 self.ser = None
2113 self.rng = np.random.default_rng(self.random_seed)
2114 self.createDynamicOpLists()
2115 self.initOpListDefaults()
2116 self.quantGen = TosaQuantGen()
2117 # Force makeShape to do a specific starting shape
2118 self.targetted_shape = None
2119
2120 def createSerializer(self, opName, testPath):
2121 self.testPath = os.path.join(opName, testPath)
2122
2123 fullPath = os.path.join(self.basePath, self.testPath)
2124 os.makedirs(fullPath, exist_ok=True)
2125 self.ser = ts.TosaSerializer(fullPath)
2126
2127 def getSerializer(self):
2128 return self.ser
2129
2130 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002131 with open(
2132 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
2133 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07002134 fd.write(self.ser.serialize())
2135
Kevin Cheng550ccc52021-03-03 11:21:43 -08002136 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
2137 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07002138
Matthew Haddon74567092021-07-16 15:38:20 +01002139 def resetRNG(self, seed=None):
2140 if seed == None:
2141 seed = self.random_seed + 1
2142 self.rng = np.random.default_rng(seed)
2143
Eric Kunzee5e26762020-10-13 16:11:07 -07002144 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07002145 if dtype == DType.BOOL:
2146 np_dt = np.bool
2147 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07002148 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002149 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002150 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002151 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002152 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
2153 elif dtype == DType.UINT8:
2154 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002155 elif dtype == DType.INT16:
2156 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
2157 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002158 return np.int32(
2159 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
2160 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002161 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002162 return np.int64(
2163 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
2164 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002165 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002166 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002167 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002168 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002169
Kevin Cheng989cb052021-04-28 16:29:44 -07002170 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002171 placeholders = []
2172
Kevin Cheng989cb052021-04-28 16:29:44 -07002173 assert len(shape_list) == len(dtype_list)
2174
2175 for idx, shape in enumerate(shape_list):
2176 arr = self.getRandTensor(shape, dtype_list[idx])
2177 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002178
2179 return placeholders
2180
Kevin Cheng989cb052021-04-28 16:29:44 -07002181 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07002182 consts = []
2183
Kevin Cheng989cb052021-04-28 16:29:44 -07002184 assert len(shape_list) == len(dtype_list)
2185
2186 for idx, shape in enumerate(shape_list):
2187 arr = self.getRandTensor(shape, dtype_list[idx])
2188 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002189
2190 return consts
2191
2192 def makeShape(self, rank):
2193 if self.targetted_shape:
2194 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002195 return np.int32(
2196 self.rng.integers(
2197 low=self.args.tensor_shape_range[0],
2198 high=self.args.tensor_shape_range[1],
2199 size=rank,
2200 )
2201 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002202
2203 def setTargetShape(self, shape):
2204 self.targetted_shape = shape
2205
2206 def randInt(self, low=0, high=256):
2207 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
2208
2209 def getRandNumberDType(self, dtype):
2210 if dtype == DType.FLOAT:
2211 return self.rng.random()
2212 elif dtype == DType.BOOL:
2213 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07002214 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07002215 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07002216 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07002217 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01002218 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07002219 elif dtype == DType.INT16:
2220 low, high = (-32768, 32768)
2221 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002222 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07002223 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002224 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07002225 # Special size
2226 return np.int64(self.rng.integers(low, high, size=1))[0]
2227 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002228 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002229
2230 return np.int32(self.rng.integers(low, high, size=1))[0]
2231
2232 def shapeStr(self, shape):
2233
2234 sStr = []
2235 # Convert to strings
2236 for i in shape:
2237 sStr.append(str(i))
2238
Kevin Cheng550ccc52021-03-03 11:21:43 -08002239 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002240
2241 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07002242 if isinstance(t, list):
2243 assert len(t) >= 2
2244 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002245 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002246 if t == DType.BOOL:
2247 return "b"
2248 elif t == DType.INT4:
2249 return "i4"
2250 elif t == DType.INT8:
2251 return "i8"
2252 elif t == DType.UINT8:
2253 return "u8"
2254 elif t == DType.INT16:
2255 return "i16"
2256 elif t == DType.INT32:
2257 return "i32"
2258 elif t == DType.INT48:
2259 return "i48"
2260 elif t == DType.FLOAT:
2261 return "float"
2262 else:
2263 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002264
2265 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002266 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002267 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002268 return 4
2269 elif t == DType.INT8:
2270 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002271 elif t == DType.UINT8:
2272 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002273 elif t == DType.INT16:
2274 return 16
2275 elif t == DType.INT32:
2276 return 32
2277 elif t == DType.INT48:
2278 return 48
2279 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002280 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002281
2282 # Argument generators
2283 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2284 # Where the string descriptor is used to generate the test name and
2285 # The build_fcn_arg_list is expanded and passed to the operator test
2286 # build function
2287
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002288 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
2289 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
2290
Matthew Haddon848efb42021-09-09 12:30:53 +01002291 # build_placeholder returns an int, ABS/other ops does not
2292 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002293 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
2294 return result_tens
2295 elif op['op'] == Op.IDENTITY:
2296 self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
2297 return result_tens
2298
2299 # Ensure new output type has correct qinfo
2300 if error_name == ErrorIf.WrongOutputType:
2301 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
2302 qinfo = ts.TosaSerializerQuantInfo()
2303 qinfo.UnaryQuantInfo(
2304 TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2305 )
2306
2307 # Invalidate Input/Output list for error if checks.
2308 input_list = [a.name]
2309 output_list = [result_tens.name]
2310 pCount, cCount = op["operands"]
2311 num_operands = pCount + cCount
2312 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2313
2314 TosaErrorValidator.evValidateErrorIfs(
2315 self.ser,
2316 validator_fcns,
2317 error_name,
2318 op=op,
2319 input_dtype=a.dtype,
2320 output_dtype=result_tens.dtype,
2321 qinfo = qinfo,
2322 result_tensor = result_tens,
2323 input_list=input_list,
2324 output_list=output_list,
2325 num_operands=num_operands,
2326 )
2327
2328 self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002329 return result_tens
2330
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002331 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
2332 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
2333
2334
2335 # Invalidate Input/Output list for error if checks.
2336 input_list = [a.name, b.name]
2337 output_list = [result_tens.name]
2338 pCount, cCount = op["operands"]
2339 num_operands = pCount + cCount
2340 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2341
2342 TosaErrorValidator.evValidateErrorIfs(
2343 self.ser,
2344 validator_fcns,
2345 error_name,
2346 op=op,
2347 input1 = a,
2348 input2 = b,
2349 input_dtype = a.dtype,
2350 output_dtype = result_tens.dtype,
2351 result_tensor = result_tens,
2352 input_list=input_list,
2353 output_list=output_list,
2354 num_operands=num_operands,
2355 )
2356
2357 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07002358 return result_tens
2359
2360 def build_binary_nonbroadcast(self, op, a, b):
2361 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002362 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002363 return result_tens
2364
Kevin Chengaee1fac2020-11-11 13:54:06 -08002365 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002366 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002367
2368 attr = ts.TosaSerializerAttribute()
2369 attr.ArithmeticRightShiftAttribute(round)
2370
Matthew Haddon848efb42021-09-09 12:30:53 +01002371 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002372 return result_tens
2373
2374 def build_mul(self, op, a, b, shift):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002375 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Eric Kunzee5e26762020-10-13 16:11:07 -07002376
2377 # Special for multiply:
2378 # Force the result to INT32 for INT types
2379 if a.dtype != DType.FLOAT:
2380 result_tens.setDtype(DType.INT32)
2381
Kevin Chengaee1fac2020-11-11 13:54:06 -08002382 attr = ts.TosaSerializerAttribute()
2383 attr.MulAttribute(shift)
2384
Matthew Haddon848efb42021-09-09 12:30:53 +01002385 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002386 return result_tens
2387
2388 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002389 # Constant size depending on type, random values
2390 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07002391 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002392 table_arr = self.getRandTensor([513], table_dtype)
2393 else:
2394 assert a.dtype == DType.INT8
2395 table_dtype = DType.INT8
2396 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002397
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002398 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
2399 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002400 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002401
2402 return result_tens
2403
2404 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07002405 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002406 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002407 return result_tens
2408
2409 def build_comparison(self, op, a, b):
2410 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002411 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002412 return result_tens
2413
2414 def build_argmax(self, op, a, axis):
2415 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
2416
2417 attr = ts.TosaSerializerAttribute()
2418 attr.AxisAttribute(axis)
2419
Matthew Haddon848efb42021-09-09 12:30:53 +01002420 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002421 return result_tens
2422
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002423 def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
2424 result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
2425
2426 # Ensure new output type has correct qinfo
2427 if error_name == ErrorIf.WrongInputType:
2428 if input.dtype not in [DType.INT8, DType.UINT8]:
2429 qinfo = ts.TosaSerializerQuantInfo()
2430 qinfo.UnaryQuantInfo(
2431 TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
2432 )
2433
2434 # Invalidate Input/Output list for error if checks.
2435 input_list = [input.name]
2436 output_list = [result_tens.name]
2437 pCount, cCount = op["operands"]
2438 num_operands = pCount + cCount
2439 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2440
2441 TosaErrorValidator.evValidateErrorIfs(
2442 self.ser,
2443 validator_fcns,
2444 error_name,
2445 op=op,
2446 input_shape=input.shape,
2447 input_dtype=input.dtype,
2448 output_shape=result_tens.shape,
2449 output_dtype=result_tens.dtype,
2450 kernel=kernel,
2451 stride=stride,
2452 pad=pad,
2453 qinfo = qinfo,
2454 result_tensor = result_tens,
2455 input_list=input_list,
2456 output_list=output_list,
2457 num_operands=num_operands,
2458 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002459
2460 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002461 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07002462
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002463 self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002464 return result_tens
2465
2466 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002467 assert len(padding) == 4
2468 result_tens = OutputShaper.conv2dOp(
2469 self.ser, ifm, filter, strides, padding, dilations
2470 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002471
2472 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002473 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002474
Kevin Cheng550ccc52021-03-03 11:21:43 -08002475 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002476 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002477 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002478 return result_tens
2479
Kevin Cheng1533b852021-09-01 12:51:58 -07002480 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
2481 assert len(padding) == 6
2482 result_tens = OutputShaper.conv3dOp(
2483 self.ser, ifm, filter, strides, padding, dilations
2484 )
2485
2486 attr = ts.TosaSerializerAttribute()
2487 attr.ConvAttribute(padding, strides, dilations)
2488
2489 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002490 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07002491 )
2492 return result_tens
2493
Kevin Cheng550ccc52021-03-03 11:21:43 -08002494 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07002495 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002496 ):
2497 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07002498 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
2499
2500 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002501 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002502
Kevin Cheng550ccc52021-03-03 11:21:43 -08002503 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002504 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002505 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002506 return result_tens
2507
Kevin Cheng550ccc52021-03-03 11:21:43 -08002508 def build_depthwise_conv2d(
2509 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
2510 ):
2511 result_tens = OutputShaper.depthwiseConv2dOp(
2512 self.ser, ifm, filter, strides, padding, dilations
2513 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002514
2515 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002516 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002517
Kevin Cheng550ccc52021-03-03 11:21:43 -08002518 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002519 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002520 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002521 return result_tens
2522
2523 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
2524 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
2525
Kevin Cheng550ccc52021-03-03 11:21:43 -08002526 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002527 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002528 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002529 return result_tens
2530
2531 def build_matmul(self, op, a, b, qinfo):
2532 result_tens = OutputShaper.matmulOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002533 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002534 return result_tens
2535
Matthew Haddond6ce7252021-09-29 15:35:44 +01002536 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
2537 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
2538
2539 # Invalidate Input/Output list for error if checks.
2540 input_list = [a.name]
2541 output_list = [result_tens.name]
2542 pCount, cCount = op["operands"]
2543 num_operands = pCount + cCount
2544 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2545
2546 TosaErrorValidator.evValidateErrorIfs(
2547 self.ser,
2548 validator_fcns,
2549 error_name,
2550 op=op,
2551 axis = axis,
2552 input_shape = a.shape,
2553 output_shape = result_tens.shape,
2554 input_dtype = a.dtype,
2555 output_dtype = result_tens.dtype,
2556 result_tensor = result_tens,
2557 input_list=input_list,
2558 output_list=output_list,
2559 num_operands=num_operands,
2560 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002561
2562 attr = ts.TosaSerializerAttribute()
2563 attr.AxisAttribute(axis)
2564
Matthew Haddond6ce7252021-09-29 15:35:44 +01002565 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002566 return result_tens
2567
2568 def build_clamp(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002569 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002570
2571 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01002572 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07002573
2574 if a.dtype == DType.FLOAT:
2575 attr.ClampAttribute(0, 0, min(v), max(v))
2576 else:
2577 attr.ClampAttribute(min(v), max(v), 0, 0)
2578
Matthew Haddon848efb42021-09-09 12:30:53 +01002579 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002580 return result_tens
2581
2582 def build_leaky_relu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002583 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002584 attr = ts.TosaSerializerAttribute()
2585
2586 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
2587
Matthew Haddon848efb42021-09-09 12:30:53 +01002588 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002589 return result_tens
2590
2591 # Needs an additional type/input
2592 def build_prelu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002593 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002594
Matthew Haddon848efb42021-09-09 12:30:53 +01002595 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002596 return result_tens
2597
Eric Kunzee5e26762020-10-13 16:11:07 -07002598 def build_sigmoid(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002599 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01002600 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002601 return result_tens
2602
2603 def build_tanh(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002604 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01002605 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002606 return result_tens
2607
Matthew Haddon818ab902021-07-27 09:12:49 +01002608 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07002609 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01002610
2611 # To store variable length list of input tensors we need to store axis along with it
2612 axis = a[-1]
2613 a = a[:-1]
2614
2615 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002616
2617 attr = ts.TosaSerializerAttribute()
2618 attr.AxisAttribute(axis)
2619
Matthew Haddon818ab902021-07-27 09:12:49 +01002620 input_tensor_names = []
2621 for tensor in a:
2622 input_tensor_names.append(tensor.name)
2623
Matthew Haddon848efb42021-09-09 12:30:53 +01002624 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
2625 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002626
2627 def build_pad(self, op, a, padding, qinfo):
2628 result_tens = OutputShaper.padOp(self.ser, a, padding)
2629
2630 # Need to turn the padding array into a TOSA tensor here.
2631 # This is one of the few tensor operands that does not get
2632 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08002633 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07002634
Kevin Cheng550ccc52021-03-03 11:21:43 -08002635 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002636 op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002637 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002638 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002639
2640 def build_reshape(self, op, a, newShape):
2641 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
2642
2643 attr = ts.TosaSerializerAttribute()
2644 attr.ReshapeAttribute(newShape)
2645
Matthew Haddon848efb42021-09-09 12:30:53 +01002646 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002647 return result_tens
2648
2649 def build_reverse(self, op, a, axis):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002650 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002651
2652 attr = ts.TosaSerializerAttribute()
2653 attr.AxisAttribute(axis)
2654
Matthew Haddon848efb42021-09-09 12:30:53 +01002655 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002656 return result_tens
2657
2658 def build_transpose(self, op, a, perms):
2659 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
2660
Kevin Cheng550ccc52021-03-03 11:21:43 -08002661 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07002662
Matthew Haddon848efb42021-09-09 12:30:53 +01002663 self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002664 return result_tens
2665
2666 def build_slice(self, op, a, begin, size):
2667 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
2668
2669 attr = ts.TosaSerializerAttribute()
2670 attr.SliceAttribute(begin, size)
2671
Matthew Haddon848efb42021-09-09 12:30:53 +01002672 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002673 return result_tens
2674
2675 def build_tile(self, op, a, multiples):
2676 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
2677
2678 attr = ts.TosaSerializerAttribute()
2679 attr.TileAttribute(multiples)
2680
Matthew Haddon848efb42021-09-09 12:30:53 +01002681 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002682 return result_tens
2683
Kevin Cheng77d0f762020-11-24 10:26:32 -08002684 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07002685
2686 # Create a new indicies tensor
2687 # here with data that doesn't exceed the dimensions of the values tensor
2688
Kevin Cheng550ccc52021-03-03 11:21:43 -08002689 K = values.shape[1] # K
2690 W = self.randInt(
2691 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
2692 ) # W
2693 indicies_arr = np.int32(
2694 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
2695 ) # (N, W)
2696 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002697
Kevin Cheng77d0f762020-11-24 10:26:32 -08002698 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07002699
Matthew Haddon848efb42021-09-09 12:30:53 +01002700 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002701
2702 return result_tens
2703
Kevin Cheng77d0f762020-11-24 10:26:32 -08002704 def build_scatter(self, op, values_in, input):
2705
2706 # Create a new indicies tensor
2707 # here with data that doesn't exceed the dimensions of the values_in tensor
2708
Kevin Cheng550ccc52021-03-03 11:21:43 -08002709 K = values_in.shape[1] # K
2710 W = input.shape[1] # W
2711 indicies_arr = np.int32(
2712 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
2713 ) # (N, W)
2714 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002715
2716 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
2717
Kevin Cheng550ccc52021-03-03 11:21:43 -08002718 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002719 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002720 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002721
2722 return result_tens
2723
Matthew Haddon848efb42021-09-09 12:30:53 +01002724
Kevin Cheng550ccc52021-03-03 11:21:43 -08002725 def build_resize(
2726 self,
2727 op,
2728 input,
2729 mode,
2730 stride,
2731 offset,
2732 shift,
2733 stride_fp,
2734 offset_fp,
2735 output_dims,
2736 input_dtype,
2737 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002738 validator_fcns,
2739 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002740 ):
2741 result_tens = OutputShaper.resizeOp(
2742 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002743 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002744 input,
2745 mode,
2746 stride,
2747 offset,
2748 shift,
2749 stride_fp,
2750 offset_fp,
2751 output_dims,
2752 input_dtype,
2753 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002754 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08002755 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002756
Matthew Haddon848efb42021-09-09 12:30:53 +01002757 # Invalidate Input/Output list for error if checks.
2758 input_list = [input.name]
2759 output_list = [result_tens.name]
2760 pCount, cCount = op["operands"]
2761 num_operands = pCount + cCount
2762 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01002763
Matthew Haddon848efb42021-09-09 12:30:53 +01002764 TosaErrorValidator.evValidateErrorIfs(
2765 self.ser,
2766 validator_fcns,
2767 error_name,
2768 op=op,
2769 mode=mode,
2770 shift=shift,
2771 input_dtype=input_dtype,
2772 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002773 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01002774 output_shape=output_dims,
2775 offset=offset,
2776 offset_fp=offset_fp,
2777 stride=stride,
2778 stride_fp=stride_fp,
2779 input_list=input_list,
2780 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002781 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01002782 num_operands=num_operands,
2783 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002784
Eric Kunzee5e26762020-10-13 16:11:07 -07002785 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08002786
Kevin Cheng550ccc52021-03-03 11:21:43 -08002787 attr.ResizeAttribute(
2788 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
2789 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002790
Matthew Haddon848efb42021-09-09 12:30:53 +01002791 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002792 return result_tens
2793
2794 def build_identityn(self, op, val, val2):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002795 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
2796 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002797 self.ser.addOperator(
2798 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
2799 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002800 return result_tens
2801
Kevin Cheng17e92022021-10-01 14:33:33 -07002802 def build_const(self, op, val):
2803 self.ser.addOutputTensor(val)
2804 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002805
2806 # Type Conversion
2807 def build_cast(self, op, val, out_dtype):
2808 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002809 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002810 return result_tens
2811
2812 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
2813 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2814
2815 if per_channel:
2816 nc = val.shape[-1]
2817 else:
2818 nc = 1
2819
2820 in_type_width = self.typeWidth(val.dtype)
2821 out_type_width = self.typeWidth(out_dtype)
2822
Kevin Cheng3a478572021-01-22 17:21:02 -08002823 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002824 input_zp = self.randInt(-128, 128)
2825 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002826 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002827 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002828 in_type_width = in_type_width + 1
2829 else:
2830 input_zp = 0
2831
Kevin Cheng3a478572021-01-22 17:21:02 -08002832 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002833 output_zp = self.randInt(-128, 128)
2834 out_type_width = out_type_width + 1
2835 elif out_dtype == DType.UINT8:
2836 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002837 out_type_width = out_type_width + 1
2838 else:
2839 output_zp = 0
2840
2841 # Calculate scale based on:
2842 # scale = a *(2^output_width)/(2^input_width))
2843
2844 a = np.float32(self.rng.random(size=[nc]))
2845 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2846
2847 if scale32:
2848 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002849 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002850 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2851 else:
2852 # Cap the scaling at 2^15 - 1 for scale16
2853 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2854
Kevin Cheng550ccc52021-03-03 11:21:43 -08002855 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002856
2857 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2858 shift_arr = np.int32(np.zeros(shape=[nc]))
2859
2860 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002861 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2862 scale_arr[i], scale32
2863 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002864
Kevin Cheng550ccc52021-03-03 11:21:43 -08002865 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002866
2867 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002868 attr.RescaleAttribute(
2869 input_zp,
2870 output_zp,
2871 multiplier_arr,
2872 shift_arr,
2873 scale32,
2874 double_round,
2875 per_channel,
2876 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002877
Matthew Haddon848efb42021-09-09 12:30:53 +01002878 self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002879 return result_tens
2880
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Serialize a COND_IF whose then/else blocks each hold a single const.

        For cond_if with constants, we're supplied with then/else tensors that
        we ignore (except for the generated shape) and the condition. Build
        Then/Else blocks and fill them with const nodes for the body.
        """
        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
2916
    def build_cond_if_binary(self, op, a, b, cond):
        """Serialize a COND_IF whose then/else blocks compute a+b and a-b.

        For cond_if with a binary op in the then/else blocks, take a and b and
        alternately add or subtract them based on the condition.
        """
        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # THEN block: out = a + b
        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        # ELSE block: out = a - b
        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
2950
    def build_while_loop(self, op, a, iter_val):
        """Serialize a WHILE_LOOP that accumulates `a` into an accumulator
        `iter_val` times, decrementing the iteration counter each pass.

        Returns the accumulator output tensor.
        """
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # COND block (input: iter, output: cond_tens )
        # Loop continues while iter > 0.
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        # Each pass: acc += a; iter -= 1.
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
3004
Matthew Haddon1c00b712021-10-01 15:51:03 +01003005 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
3006 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
3007 default_test_rank_range = range(1, 5)
3008 if not shapeFilter:
3009 shapeFilter = [None]
3010
3011 # Calculate the filters based on what is requested and what the operator allows
3012 rmin, rmax = op["rank"]
3013 if rankFilter is not None:
3014 cleanRankFilter = []
3015 # Ensure rankFilter values are allowed by operator
3016 for rank in rankFilter:
3017 if rank >= rmin and rank <= rmax:
3018 cleanRankFilter.append(rank)
3019 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01003020 # Ensure default behaviour is bounded by default range or by operator,
3021 # whichever is the smaller range of ranks.
3022 opRankRange = range(rmin, rmax + 1)
3023 cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
Matthew Haddon1c00b712021-10-01 15:51:03 +01003024 else:
3025 cleanRankFilter = range(rmin, rmax + 1)
3026
3027 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003028
Matthew Haddon1c00b712021-10-01 15:51:03 +01003029 if dtypeFilter is not None:
3030 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01003031 # Create list of operator dtypes filtered by requested dtypes
3032 for dtype in dtypes:
3033 if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
Matthew Haddon1c00b712021-10-01 15:51:03 +01003034 cleanDtypeFilter.append(dtype)
3035 else:
3036 cleanDtypeFilter = dtypes
3037
3038 if testType == 'positive':
3039 filterDict = {
3040 'shapeFilter': shapeFilter,
3041 'rankFilter': cleanRankFilter,
3042 'dtypeFilter': cleanDtypeFilter
3043 }
3044 return filterDict
3045 elif testType == 'negative':
3046 validator_info = validator(check=False, op=op)
3047 error_arguments = validator_info['param_reqs']
3048
3049 #Set parameters as required
3050 if error_arguments['rank'] != None:
3051 rankFilter = error_arguments['rank']
3052 else:
3053 rankFilter = cleanRankFilter
3054
3055 if error_arguments['dtype'] != None:
3056 dtypeFilter = error_arguments['dtype']
3057 else:
3058 dtypeFilter = cleanDtypeFilter
3059
3060 if error_arguments['shape'] != None:
3061 shapeFilter = error_arguments['shape']
3062 else:
3063 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
3064
3065 filterDict = {
3066 'shapeFilter': shapeFilter,
3067 'rankFilter': rankFilter,
3068 'dtypeFilter': dtypeFilter
3069 }
3070 return filterDict
3071
3072
Kevin Cheng550ccc52021-03-03 11:21:43 -08003073 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01003074 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08003075 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07003076
3077 try:
3078 op = self.TOSA_OP_LIST[opName]
3079 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003080 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003081
3082 # Initialize a new random number generator
3083 self.rng = np.random.default_rng(self.random_seed)
3084
Kevin Cheng550ccc52021-03-03 11:21:43 -08003085 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003086
Eric Kunzee5e26762020-10-13 16:11:07 -07003087 # Test list consists of a tuple of:
3088 # (opName, testNameStr, dtype, shapeList, argumentsList)
3089 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01003090 if testType == 'negative' and "error_if_validators" in op:
3091 error_if_validators = op["error_if_validators"]
3092 else:
3093 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07003094
Matthew Haddon1c00b712021-10-01 15:51:03 +01003095 for validator in error_if_validators:
3096 if validator is not None:
3097 error_name = validator(check=False, op=op)['error_name']
3098 #print("error_name: ", error_name)
3099 else:
3100 error_name = None
3101
3102 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
3103 cleanRankFilter = filterDict['rankFilter']
3104 cleanDtypeFilter = filterDict['dtypeFilter']
3105 cleanShapeFilter = filterDict['shapeFilter']
3106 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
3107
3108 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07003109 if opName.startswith("conv3d"):
3110 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01003111 for t in cleanDtypeFilter:
3112 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01003113 # Filter out by rank
3114 if shape is not None and len(shape) != r:
3115 continue
Matthew Haddon74567092021-07-16 15:38:20 +01003116 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003117 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003118
Matthew Haddon74567092021-07-16 15:38:20 +01003119 shapeStr = self.shapeStr(shapeList[0])
3120 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07003121
Matthew Haddon74567092021-07-16 15:38:20 +01003122 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
3123 argList = []
3124 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003125 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003126 else:
Matthew Haddon74567092021-07-16 15:38:20 +01003127 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07003128
Matthew Haddon74567092021-07-16 15:38:20 +01003129 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01003130 if testType == 'positive':
3131 if argStr:
3132 testStr = "{}_{}_{}_{}".format(
3133 opName, shapeStr, typeStr, argStr
3134 )
3135 else:
3136 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
3137 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01003138 if argStr:
3139 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
3140 opName, error_name, shapeStr, typeStr, argStr
3141 )
3142 else:
3143 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003144
3145 testList.append((opName, testStr, t, error_name, shapeList, args))
3146
3147 if testType == 'positive':
3148 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
3149 if "invalid_test_validators" in op:
3150 invalid_test_validators = op["invalid_test_validators"]
3151 clean_testList = []
3152 for test in testList:
3153 for validator_fcn in invalid_test_validators:
3154 remove_test = False
3155 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
3156 remove_test = True
3157 if not remove_test:
3158 clean_testList.append(test)
3159 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07003160
3161 return testList
3162
Matthew Haddone86fd342021-09-07 16:12:21 +01003163
3164 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07003165 try:
3166 op = self.TOSA_OP_LIST[opName]
3167 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003168 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003169
3170 # Create a serializer
3171 self.createSerializer(opName, testStr)
3172
Kevin Cheng550ccc52021-03-03 11:21:43 -08003173 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003174 if "error_if_validators" in op:
3175 error_if_validators = op["error_if_validators"]
3176 else:
3177 error_if_validators = None
3178
Kevin Cheng550ccc52021-03-03 11:21:43 -08003179 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003180 num_operands = pCount + cCount
3181
3182 if isinstance(dtype_or_dtypeList, list):
3183 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07003184 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003185 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003186 else:
3187 dtypeList = [dtype_or_dtypeList] * (num_operands)
3188
Kevin Cheng93a16282021-08-31 16:14:03 -07003189 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01003190 assert (
3191 len(shapeList) == num_operands
3192 ), "shapeList length {} must match number of operands {}".format(
3193 len(shapeList), num_operands
3194 )
3195 assert (
3196 len(dtypeList) == num_operands
3197 ), "dtypeList length {} must match number of operands {}".format(
3198 len(dtypeList), num_operands
3199 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003200
3201 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003202 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003203 except KeyError:
3204 qgen = None
3205
3206 # Build the random tensor operands and the test
3207 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08003208
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003209 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003210
3211 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003212 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003213 else:
3214 qinfo = None
3215
3216 try:
3217 if error_if_validators is None:
3218 if qinfo is not None:
3219 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
3220 else:
3221 resultName = build_fcn(self, op, *tens, *testArgs)
3222 else:
3223 if qinfo is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003224 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003225 else:
3226 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
3227 except TypeError as e:
3228 print(
3229 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
3230 build_fcn, tens, testArgs
3231 )
3232 )
3233 raise e
3234
3235 if resultName is None:
3236 print("Invalid ERROR_IF tests created")
3237
3238 # Save the serialized test
3239 self.serialize("test")
3240
3241
    def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create the operand tensors (placeholders/consts) for one test.

        Most ops get random placeholder/const tensors, but several ops need
        specially-constrained data so that positive tests avoid undefined or
        saturating behaviour:
          - ADD/SUB on INT32: clip operand b so the result cannot saturate
          - ARITHMETIC_RIGHT_SHIFT: shift amount forced within [0, num_bits]
          - SELECT: condition tensor forced to BOOL
          - INTDIV: regenerate until no divide-by-zero or INT_MIN / -1 cases
          - MUL: halve operands until the product fits in int32
          - CONCAT: split inputs into placeholder and const groups
        Returns the list of serializer tensors in operand order.
        """
        pCount, cCount = op["operands"]

        tens = []
        if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = (op["op"] == Op.ADD)
            a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the exact result in int64 to detect would-be saturation
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31)-1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                # (b may be broadcast, so any axis of size 1 must stay size 1)
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            # Force value of operand[1] to be within [0, num_bits]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Operand 1 is the shift amount: bound it by the bit width
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.INTDIV and error_name == None:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Keep regenerating random data until neither case is present.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop iterates once per entry in shapeList
                # but always indexes shapeList[0]/shapeList[1] directly, so
                # later iterations just regenerate a_arr/b_arr — confirm
                # whether iterating here is intentional.
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    # Halve the operands until the (rounded, shifted) product
                    # fits within the int32 range.  'i' only counts attempts.
                    i = 0
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2 ** 31)).all() and (
                            result_arr <= ((2 ** 31) - 1)
                        ).all():
                            break

                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            # Split the inputs between placeholders and consts as configured;
            # always keep at least one placeholder input.
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            # Default: first pCount operands are placeholders, rest are consts
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003442
3443 def createDynamicOpLists(self):
3444
3445 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07003446 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003447
Kevin Cheng1533b852021-09-01 12:51:58 -07003448 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003449 testName = "conv2d_{}x{}".format(k[0], k[1])
3450 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
3451 self.TOSA_OP_LIST[testName]["filter"] = k
3452 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003453
Kevin Cheng550ccc52021-03-03 11:21:43 -08003454 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
3455 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3456 "depthwise_conv2d_TEMPLATE"
3457 ].copy()
3458 self.TOSA_OP_LIST[testName]["filter"] = k
3459 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003460
Kevin Cheng550ccc52021-03-03 11:21:43 -08003461 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
3462 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3463 "transpose_conv2d_TEMPLATE"
3464 ].copy()
3465 self.TOSA_OP_LIST[testName]["filter"] = k
3466 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003467
Kevin Cheng1533b852021-09-01 12:51:58 -07003468 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
3469 for k in KERNELS_3D:
3470 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
3471 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
3472 self.TOSA_OP_LIST[testName]["filter"] = k
3473 self.TOSA_OP_LIST[testName]["template"] = False
3474
Eric Kunzee5e26762020-10-13 16:11:07 -07003475 # Delete any templates after having created any dynamic ops
3476 # This is a two-pass operation because it's bad practice to delete
3477 # keys from dictionaries while iterating
3478 keyList = []
3479 for k in self.TOSA_OP_LIST:
3480 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003481 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07003482 keyList.append(k)
3483 continue
3484 except KeyError:
3485 pass
3486
3487 for k in keyList:
3488 del self.TOSA_OP_LIST[k]
3489
3490 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003491 """Fill in default fields for ops if they aren't already specified.
3492 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07003493 for op in self.TOSA_OP_LIST:
3494
3495 # Required fields
3496 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003497 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003498 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003499 raise Exception(
3500 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
3501 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003502
3503 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003504 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003505 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003506 raise Exception(
3507 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
3508 op
3509 )
3510 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003511
3512 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003513 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003514 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003515 raise Exception(
3516 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
3517 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003518
3519 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003520 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003521 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003522 raise Exception(
3523 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
3524 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003525
3526 # Put in default rank range, if missing
3527 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003528 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003529 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003530 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003531
    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    # if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested

    # Shorthand dtype lists referenced by the TOSA_OP_LIST entries below.
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Per-op dtype combinations for convolution-style ops.  Each list entry
    # appears to be an [input, weight, accumulator] dtype triple — TODO
    # confirm against the conv build/qgen functions.
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Rank range used when an op entry omits 'rank'; upper bound comes from
    # the serialization library (TOSA_TENSOR_MAX_RANK).
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003559
3560 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003561 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003562 "argmax": {
3563 "op": Op.ARGMAX,
3564 "operands": (1, 0),
3565 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3566 "types": TYPE_NARROW_INT_FP,
3567 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003568 "avg_pool2d": {
3569 "op": Op.AVG_POOL2D,
3570 "operands": (1, 0),
3571 "rank": (4, 4),
3572 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3573 "qgen": TosaQuantGen.qgUnary,
3574 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003575 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
3576 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
3577 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
3578 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
3579 TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08003580 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003581 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 "conv2d_TEMPLATE": {
3583 "op": Op.CONV2D,
3584 "operands": (1, 2),
3585 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01003586 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003587 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003588 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003589 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003590 "template": True,
3591 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003592 # Templated operator. Filled in by createDynamicOpLists
3593 "conv3d_TEMPLATE": {
3594 "op": Op.CONV3D,
3595 "operands": (1, 2),
3596 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01003597 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07003598 "qgen": TosaQuantGen.qgConv,
3599 "types": TYPE_CONV,
3600 "template": True,
3601 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003602 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003603 "depthwise_conv2d_TEMPLATE": {
3604 "op": Op.DEPTHWISE_CONV2D,
3605 "operands": (1, 2),
3606 "filter": [1, 1],
3607 "rank": (4, 4),
3608 "build_fcn": (
3609 build_depthwise_conv2d,
3610 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01003611 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003612 ),
3613 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003614 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003615 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003616 "template": True,
3617 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 "fully_connected": {
3619 "op": Op.FULLY_CONNECTED,
3620 "operands": (1, 2),
3621 "rank": (2, 2),
3622 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
3623 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003624 "types": TYPE_CONV,
Jared Smolens573ecd42021-03-04 15:24:10 -08003625 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 "matmul": {
3627 "op": Op.MATMUL,
3628 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003629 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08003630 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
3631 "qgen": TosaQuantGen.qgMatmul,
3632 "types": TYPE_NARROW_INT_FP,
3633 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003634 "max_pool2d": {
3635 "op": Op.MAX_POOL2D,
3636 "operands": (1, 0),
3637 "rank": (4, 4),
3638 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3639 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003640 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
3641 "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
3642 TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
3643 TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
Jared Smolens573ecd42021-03-04 15:24:10 -08003644 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003645 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003646 "transpose_conv2d_TEMPLATE": {
3647 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003648 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003649 "rank": (4, 4),
3650 "build_fcn": (
3651 build_transpose_conv2d,
3652 TosaTensorGen.tgTransposeConv2D,
3653 TosaArgGen.agTransposeConv2D,
3654 ),
3655 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003656 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003657 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003658 "template": True,
3659 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003660 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003661 "clamp": {
3662 "op": Op.CLAMP,
3663 "operands": (1, 0),
3664 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
3665 "types": TYPE_NARROW_INT_FP,
3666 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 "sigmoid": {
3668 "op": Op.SIGMOID,
3669 "operands": (1, 0),
3670 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
3671 "types": TYPE_FP,
3672 },
3673 "tanh": {
3674 "op": Op.TANH,
3675 "operands": (1, 0),
3676 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
3677 "types": TYPE_FP,
3678 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 # Elementwise Binary Operators
3680 "add": {
3681 "op": Op.ADD,
3682 "operands": (2, 0),
3683 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3684 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003685 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3686 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 "arithmetic_right_shift": {
3689 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3690 "operands": (2, 0),
3691 "build_fcn": (
3692 build_arithmetic_right_shift,
3693 TosaTensorGen.tgBroadcastFuzz,
3694 TosaArgGen.agArithmeticRightShift,
3695 ),
3696 "types": TYPE_INT,
3697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 "bitwise_and": {
3699 "op": Op.BITWISE_AND,
3700 "operands": (2, 0),
3701 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3702 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003703 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3704 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003706 "bitwise_or": {
3707 "op": Op.BITWISE_OR,
3708 "operands": (2, 0),
3709 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3710 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003711 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3712 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003713 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003714 "bitwise_xor": {
3715 "op": Op.BITWISE_XOR,
3716 "operands": (2, 0),
3717 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3718 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003719 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3720 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003722 "intdiv": {
3723 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003724 "operands": (2, 0),
3725 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3726 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003727 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3728 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003729 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 "logical_and": {
3731 "op": Op.LOGICAL_AND,
3732 "operands": (2, 0),
3733 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3734 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003735 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3736 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "logical_left_shift": {
3739 "op": Op.LOGICAL_LEFT_SHIFT,
3740 "operands": (2, 0),
3741 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3742 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003743 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3744 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003746 "logical_right_shift": {
3747 "op": Op.LOGICAL_RIGHT_SHIFT,
3748 "operands": (2, 0),
3749 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3750 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003751 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3752 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003754 "logical_or": {
3755 "op": Op.LOGICAL_OR,
3756 "operands": (2, 0),
3757 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3758 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003759 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3760 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003762 "logical_xor": {
3763 "op": Op.LOGICAL_XOR,
3764 "operands": (2, 0),
3765 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3766 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003767 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3768 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 "maximum": {
3771 "op": Op.MAXIMUM,
3772 "operands": (2, 0),
3773 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3774 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003775 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3776 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003777 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 "minimum": {
3779 "op": Op.MINIMUM,
3780 "operands": (2, 0),
3781 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3782 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003783 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3784 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "mul": {
3787 "op": Op.MUL,
3788 "operands": (2, 0),
3789 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
3790 "types": TYPE_INT_FP,
3791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 "pow": {
3793 "op": Op.POW,
3794 "operands": (2, 0),
3795 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
3796 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003797 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3798 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003799 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003800 "sub": {
3801 "op": Op.SUB,
3802 "operands": (2, 0),
3803 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3804 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003805 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3806 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 "table": {
3809 "op": Op.TABLE,
3810 # Use the automatic generation functions to create the input array
3811 # but create the table tensor in the build function, as it may be
3812 # a different type from the input
3813 "operands": (1, 0),
3814 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003815 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08003816 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003817 # Elementwise Unary operators
3818 "abs": {
3819 "op": Op.ABS,
3820 "operands": (1, 0),
3821 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3822 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003823 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3824 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003825 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003826 "bitwise_not": {
3827 "op": Op.BITWISE_NOT,
3828 "operands": (1, 0),
3829 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3830 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003831 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3832 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003834 "ceil": {
3835 "op": Op.CEIL,
3836 "operands": (1, 0),
3837 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3838 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003839 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3840 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003842 "clz": {
3843 "op": Op.CLZ,
3844 "operands": (1, 0),
3845 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3846 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003847 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3848 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003849 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "exp": {
3851 "op": Op.EXP,
3852 "operands": (1, 0),
3853 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3854 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003855 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3856 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003857 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003858 "floor": {
3859 "op": Op.FLOOR,
3860 "operands": (1, 0),
3861 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3862 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003863 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3864 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003865 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 "log": {
3867 "op": Op.LOG,
3868 "operands": (1, 0),
3869 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3870 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003871 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3872 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 "logical_not": {
3875 "op": Op.LOGICAL_NOT,
3876 "operands": (1, 0),
3877 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3878 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003879 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3880 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003882 "negate": {
3883 "op": Op.NEGATE,
3884 "operands": (1, 0),
3885 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3886 "qgen": TosaQuantGen.qgUnary,
3887 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003888 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
3889 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
3890 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003892 "reciprocal": {
3893 "op": Op.RECIPROCAL,
3894 "operands": (1, 0),
3895 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3896 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003897 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3898 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "rsqrt": {
3901 "op": Op.RSQRT,
3902 "operands": (1, 0),
3903 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3904 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003905 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3906 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003908 # Elementwise Ternary operators
3909 "select": {
3910 "op": Op.SELECT,
3911 "operands": (3, 0),
3912 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
3913 "types": TYPE_FIB,
3914 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003915 # Comparison operators
3916 "equal": {
3917 "op": Op.EQUAL,
3918 "operands": (2, 0),
3919 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3920 "types": TYPE_FI32,
3921 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003922 "greater_equal": {
3923 "op": Op.GREATER_EQUAL,
3924 "operands": (2, 0),
3925 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3926 "types": TYPE_FI32,
3927 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003928 "greater": {
3929 "op": Op.GREATER,
3930 "operands": (2, 0),
3931 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3932 "types": TYPE_FI32,
3933 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003934 # Reduction operators
3935 "reduce_all": {
3936 "op": Op.REDUCE_ALL,
3937 "operands": (1, 0),
3938 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3939 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003940 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3941 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3942 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003944 "reduce_any": {
3945 "op": Op.REDUCE_ANY,
3946 "operands": (1, 0),
3947 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3948 "types": TYPE_BOOL,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003949 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3950 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3951 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003952 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 "reduce_max": {
3954 "op": Op.REDUCE_MAX,
3955 "operands": (1, 0),
3956 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3957 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003958 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3959 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3960 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003961 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003962 "reduce_min": {
3963 "op": Op.REDUCE_MAX,
3964 "operands": (1, 0),
3965 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3966 "types": TYPE_INT_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003967 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3968 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3969 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003971 "reduce_product": {
3972 "op": Op.REDUCE_PRODUCT,
3973 "operands": (1, 0),
3974 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3975 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003976 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3977 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3978 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003980 "reduce_sum": {
3981 "op": Op.REDUCE_SUM,
3982 "operands": (1, 0),
3983 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3984 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003985 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3986 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3987 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003988 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003989 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003990 "concat": {
3991 "op": Op.CONCAT,
3992 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01003993 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003994 "types": TYPE_FIB,
3995 },
3996 "pad": {
3997 "op": Op.PAD,
3998 "operands": (1, 0),
3999 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
4000 "qgen": TosaQuantGen.qgPad,
4001 "types": TYPE_FIB,
4002 },
4003 "reshape": {
4004 "op": Op.RESHAPE,
4005 "operands": (1, 0),
4006 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
4007 "types": TYPE_FIB,
4008 },
4009 "reverse": {
4010 "op": Op.REVERSE,
4011 "operands": (1, 0),
4012 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
4013 "types": TYPE_FIB,
4014 },
4015 "slice": {
4016 "op": Op.SLICE,
4017 "operands": (1, 0),
4018 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
4019 "types": TYPE_FIB,
4020 },
4021 "tile": {
4022 "op": Op.TILE,
4023 "operands": (1, 0),
4024 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
4025 "types": TYPE_FIB,
4026 },
4027 "transpose": {
4028 "op": Op.TRANSPOSE,
4029 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01004030 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004031 "build_fcn": (
4032 build_transpose,
4033 TosaTensorGen.tgBasic,
4034 TosaArgGen.agTranspose,
4035 ),
4036 "types": TYPE_FIB,
4037 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004038 # Data nodes
4039 "const": {
4040 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004041 "operands": (0, 1),
4042 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08004043 "types": TYPE_FIB,
4044 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004045 "identity": {
4046 "op": Op.IDENTITY,
4047 "operands": (1, 0),
4048 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
4049 "types": TYPE_FIB,
4050 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004051 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004052 "gather": {
4053 "op": Op.GATHER,
4054 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4055 "operands": (1, 0),
4056 "rank": (3, 3),
4057 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
4058 "types": TYPE_INT_FP,
4059 },
4060 "scatter": {
4061 "op": Op.SCATTER,
4062 # Only specify 'values_in' tensor here.
4063 #'indices' and 'input' are generated in op building stage
4064 "operands": (2, 0),
4065 "rank": (3, 3),
4066 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
4067 "types": TYPE_INT_FP,
4068 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004069 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004070 "resize": {
4071 "op": Op.RESIZE,
4072 "operands": (1, 0),
4073 "rank": (4, 4),
4074 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
4075 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01004076 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
4077 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
4078 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01004079 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004080 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
4081 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004082 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004083 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004084 "cast": {
4085 "op": Op.CAST,
4086 "operands": (1, 0),
4087 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
4088 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
4089 },
4090 "rescale": {
4091 "op": Op.RESCALE,
4092 "operands": (1, 0),
4093 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01004094 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08004095 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004096 # Custom
4097 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004098 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004099 # Two varients of cond_if, one that generates one of two constant tensors (no
4100 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4101 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004102 "cond_if_const": {
4103 "op": Op.COND_IF,
4104 "operands": (0, 2),
4105 "build_fcn": (
4106 build_cond_if_const,
4107 TosaTensorGen.tgBasic,
4108 TosaArgGen.agCondIf,
4109 ),
4110 "types": [DType.BOOL],
4111 },
4112 "cond_if_binary": {
4113 "op": Op.COND_IF,
4114 "operands": (2, 0),
4115 "build_fcn": (
4116 build_cond_if_binary,
4117 TosaTensorGen.tgBasic,
4118 TosaArgGen.agCondIf,
4119 ),
4120 "types": TYPE_FI32,
4121 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004122 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004123 "while_loop": {
4124 "op": Op.WHILE_LOOP,
4125 "operands": (0, 1),
4126 "build_fcn": (
4127 build_while_loop,
4128 TosaTensorGen.tgBasic,
4129 TosaArgGen.agWhileLoop,
4130 ),
4131 "types": [DType.INT32],
4132 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004133 }
4134
Kevin Cheng550ccc52021-03-03 11:21:43 -08004135
class OutputShaper:
    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        # No per-instance state: every shape-inference helper below is a
        # @staticmethod, so instantiating this class is never required.
        pass
4141
4142 # These methods return arguments that can be used for
4143 # creating a new output tensor
4144 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004145 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4146 if error_name != ErrorIf.RankMismatch:
4147 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004148 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004149
4150 shape = []
4151 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004152 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004153 shape.append(b.shape[i])
4154 else:
4155 shape.append(a.shape[i])
4156
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004157 if error_name == ErrorIf.WrongOutputType:
4158 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4159 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4160 outputDType = rng.choice(wrong_dtypes)
4161 else:
4162 outputDType = a.dtype
4163
4164 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004165
4166 @staticmethod
4167 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004168 assert len(a.shape) == len(b.shape)
4169 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004170
4171 shape = []
4172 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004173 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004174 shape.append(a.shape[i])
4175
Kevin Cheng550ccc52021-03-03 11:21:43 -08004176 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004177
4178 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004179 def unaryOp(ser, rng, a, error_name=None):
4180 if error_name == ErrorIf.WrongOutputType:
4181 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4182 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4183 outputDType = rng.choice(wrong_dtypes)
4184 else:
4185 outputDType = a.dtype
4186
4187 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004188
4189 @staticmethod
4190 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004191 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
4192 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004193
4194 shape = []
4195 for i in range(len(a.shape)):
4196 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4197
Kevin Cheng550ccc52021-03-03 11:21:43 -08004198 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004199
4200 @staticmethod
4201 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004202 assert len(a.shape) == len(b.shape)
4203 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004204
4205 # Do broadcast
4206 shape = []
4207 for i in range(len(a.shape)):
4208 if a.shape[i] == 1:
4209 shape.append(b.shape[i])
4210 else:
4211 shape.append(a.shape[i])
4212
4213 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08004214 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004215
4216 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004217 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004218 shape = a.shape.copy()
Matthew Haddond6ce7252021-09-29 15:35:44 +01004219 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
4220 shape[axis] = 1
4221 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4222 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004223
Matthew Haddond6ce7252021-09-29 15:35:44 +01004224 if error_name == ErrorIf.WrongOutputType:
4225 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4226 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4227 outputDType = rng.choice(wrong_dtypes)
4228 else:
4229 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004230
Matthew Haddond6ce7252021-09-29 15:35:44 +01004231 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004232
4233 @staticmethod
4234 def argmaxOp(ser, a, axis):
4235 shape = a.shape.copy()
4236 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004237 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07004238
4239 @staticmethod
4240 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
4241
4242 # IFM: NHWC
4243 # Filter: OHWI
4244 # OFM: NHWC
4245
4246 if len(padding) == 2:
4247 # Expand padding to 4 parameters in the case of transpose_conv2d
4248 # From H,W to T,B,L,R
4249 padding = [padding[0], padding[0], padding[1], padding[1]]
4250
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 h = (
4252 ifm.shape[1]
4253 - filter.shape[1]
4254 - (filter.shape[1] - 1) * (dilations[0] - 1)
4255 + padding[0]
4256 + padding[1]
4257 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004258
Kevin Cheng550ccc52021-03-03 11:21:43 -08004259 w = (
4260 ifm.shape[2]
4261 - filter.shape[2]
4262 - (filter.shape[2] - 1) * (dilations[1] - 1)
4263 + padding[2]
4264 + padding[3]
4265 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004266
Eric Kunzee5e26762020-10-13 16:11:07 -07004267 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4268
Kevin Cheng3a478572021-01-22 17:21:02 -08004269 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004270 out_dtype = DType.INT32
4271 elif ifm.dtype == DType.INT16:
4272 out_dtype = DType.INT48
4273 elif ifm.dtype == DType.FLOAT:
4274 out_dtype = DType.FLOAT
4275 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004276 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004277
Kevin Cheng550ccc52021-03-03 11:21:43 -08004278 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004279
4280 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07004281 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
4282
4283 # IFM: NDHWC
4284 # Filter: ODHWI
4285 # OFM: NDHWC
4286
4287 d = (
4288 ifm.shape[1]
4289 - filter.shape[1]
4290 - (filter.shape[1] - 1) * (dilations[0] - 1)
4291 + padding[0]
4292 + padding[1]
4293 ) // strides[0] + 1
4294
4295 h = (
4296 ifm.shape[2]
4297 - filter.shape[2]
4298 - (filter.shape[2] - 1) * (dilations[1] - 1)
4299 + padding[2]
4300 + padding[3]
4301 ) // strides[1] + 1
4302
4303 w = (
4304 ifm.shape[3]
4305 - filter.shape[3]
4306 - (filter.shape[3] - 1) * (dilations[2] - 1)
4307 + padding[4]
4308 + padding[5]
4309 ) // strides[2] + 1
4310
4311 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4312
4313 if ifm.dtype == DType.INT8:
4314 out_dtype = DType.INT32
4315 elif ifm.dtype == DType.INT16:
4316 out_dtype = DType.INT48
4317 elif ifm.dtype == DType.FLOAT:
4318 out_dtype = DType.FLOAT
4319 else:
4320 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
4321
4322 return ser.addOutput(ofm_shape, out_dtype)
4323
4324 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07004325 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
4326 # IFM: NHWC
4327 # Filter: HWCM
4328 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08004329 h = (
4330 ifm.shape[1]
4331 - filter.shape[0]
4332 - (filter.shape[0] - 1) * (dilations[0] - 1)
4333 + padding[0]
4334 + padding[1]
4335 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004336
Kevin Cheng550ccc52021-03-03 11:21:43 -08004337 w = (
4338 ifm.shape[2]
4339 - filter.shape[1]
4340 - (filter.shape[1] - 1) * (dilations[1] - 1)
4341 + padding[2]
4342 + padding[3]
4343 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004344
Eric Kunzee5e26762020-10-13 16:11:07 -07004345 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4346
Kevin Cheng3a478572021-01-22 17:21:02 -08004347 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004348 out_dtype = DType.INT32
4349 elif ifm.dtype == DType.INT16:
4350 out_dtype = DType.INT48
4351 elif ifm.dtype == DType.FLOAT:
4352 out_dtype = DType.FLOAT
4353 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004354 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004355
Kevin Cheng550ccc52021-03-03 11:21:43 -08004356 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004357
4358 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004359 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004360 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004361 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
4362 # If an incorrect stride is used set dimensions to 0, test is invalid anyway.
4363 h = 1
4364 w = 1
4365 else:
4366 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
4367 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
4368
4369 if error_name == ErrorIf.PoolingOutputShapeMismatch:
4370 choices = [1, 2, 3, 4, 5]
4371 h = h + rng.choice(choices)
4372 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07004373
Eric Kunzee5e26762020-10-13 16:11:07 -07004374 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004375
4376 if error_name == ErrorIf.WrongOutputType:
4377 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
4378 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4379 outputDType = rng.choice(wrong_dtypes)
4380 else:
4381 outputDType = ifm.dtype
4382
4383 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004384
4385 @staticmethod
4386 def fullyConnectedOp(ser, input, filter):
4387 # input: N, IC
4388 # filter: OC, IC
4389 # output: N, OC
4390
4391 output_shape = [input.shape[0], filter.shape[0]]
4392
Kevin Cheng3a478572021-01-22 17:21:02 -08004393 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004394 out_dtype = DType.INT32
4395 elif input.dtype == DType.INT16:
4396 out_dtype = DType.INT48
4397 elif input.dtype == DType.FLOAT:
4398 out_dtype = DType.FLOAT
4399 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004400 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004401
Kevin Cheng550ccc52021-03-03 11:21:43 -08004402 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004403
4404 @staticmethod
4405 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004406 # a: N, H, C
4407 # b: N, C, W
4408 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004409
Kevin Cheng2d60f002021-06-09 14:18:32 -07004410 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004411
Kevin Cheng3a478572021-01-22 17:21:02 -08004412 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004413 out_dtype = DType.INT32
4414 elif a.dtype == DType.INT16:
4415 out_dtype = DType.INT48
4416 elif a.dtype == DType.FLOAT:
4417 out_dtype = DType.FLOAT
4418 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004419 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004420
Kevin Cheng550ccc52021-03-03 11:21:43 -08004421 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004422
4423 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01004424 def concatOp(ser, axis, *a):
4425 input1 = a[0]
4426 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004427
Matthew Haddon818ab902021-07-27 09:12:49 +01004428 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07004429
Matthew Haddon818ab902021-07-27 09:12:49 +01004430 output_shape[axis] = input1.shape[axis]
4431
4432 for tensor in remaining_inputs:
4433 output_shape[axis] += tensor.shape[axis]
4434
4435 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004436
4437 @staticmethod
4438 def padOp(ser, a, padding):
4439
4440 output_shape = a.shape.copy()
4441
4442 for i in range(len(output_shape)):
4443 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4444
Kevin Cheng550ccc52021-03-03 11:21:43 -08004445 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004446
4447 @staticmethod
4448 def reshapeOp(ser, a, shape):
4449 output_shape = shape.copy()
4450
4451 totalElements = 1
4452 for i in a.shape:
4453 totalElements *= i
4454
4455 # If there are any -1 elements, figure out what that dimension must be
4456 totalOutputElements = 1
4457 for i in output_shape:
4458 if i != -1:
4459 totalOutputElements *= i
4460
4461 # And fill it in
4462 for i in range(len(output_shape)):
4463 if output_shape[i] == -1:
4464 output_shape[i] = totalElements // totalOutputElements
4465
Kevin Cheng550ccc52021-03-03 11:21:43 -08004466 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004467
4468 @staticmethod
4469 def sliceOp(ser, a, begin, size):
4470
4471 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004472 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004473
4474 @staticmethod
4475 def tileOp(ser, a, multiples):
4476
4477 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004478 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004479
4480 for i in range(len(output_shape)):
4481 output_shape[i] = a.shape[i] * multiples[i]
4482
Kevin Cheng550ccc52021-03-03 11:21:43 -08004483 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004484
4485 @staticmethod
4486 def transposeOp(ser, a, perms):
4487 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004488 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004489
4490 for i in range(len(output_shape)):
4491 output_shape[i] = a.shape[perms[i]]
4492
Kevin Cheng550ccc52021-03-03 11:21:43 -08004493 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004494
4495 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08004496 def gatherOp(ser, values, indices):
4497 assert len(values.shape) == 3
4498 assert len(indices.shape) == 2
4499 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004500
Kevin Cheng77d0f762020-11-24 10:26:32 -08004501 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4502
Kevin Cheng550ccc52021-03-03 11:21:43 -08004503 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004504
4505 @staticmethod
4506 def scatterOp(ser, values_in, indices, input):
4507 assert len(values_in.shape) == 3
4508 assert len(indices.shape) == 2
4509 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004510 assert values_in.shape[0] == indices.shape[0] # N
4511 assert input.shape[1] == indices.shape[1] # W
4512 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004513
4514 output_shape = values_in.shape
4515
Kevin Cheng550ccc52021-03-03 11:21:43 -08004516 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004517
4518 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004519 def tableOp(ser, input, table_dtype):
4520 # Same shape as the input, but dtype dependent on table dtype
4521 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
4522 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
4523 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004524
4525 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004526 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004527 serializer,
4528 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004529 input,
4530 mode,
4531 stride,
4532 offset,
4533 shift,
4534 stride_fp,
4535 offset_fp,
4536 output_dims,
4537 input_dtype,
4538 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004539 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08004540 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01004541 if error_name == ErrorIf.WrongRank:
4542 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
4543 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004544 if error_name == ErrorIf.BatchMismatch:
4545 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
4546 elif error_name == ErrorIf.ChannelMismatch:
4547 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
4548 else:
4549 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004550
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004551 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004552
4553 @staticmethod
4554 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004555 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004556
4557 @staticmethod
4558 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08004559 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004560 out_dtype = DType.INT32
4561 elif ifm.dtype == DType.INT16:
4562 out_dtype = DType.INT48
4563 elif ifm.dtype == DType.FLOAT:
4564 out_dtype = DType.FLOAT
4565 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004566 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004567
Kevin Cheng550ccc52021-03-03 11:21:43 -08004568 return ser.addOutput(output_shape, out_dtype)