#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique
from tosa_ref_run import TosaReturnCode

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa
from tosa_error_if import ErrorIf

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()

class TosaQuantGen:
    """QuantizedInfo random generator helper functions.  Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

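    # Note: the zero points above span the full INT8 range [-128, 127] and
    # UINT8 range [0, 255]; this assumes testGen.randInt treats its upper
    # bound as exclusive (numpy Generator.integers semantics).
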
    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift

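    # Worked example (illustrative, not executed): computeMultiplierAndShift(0.25, True)
    # gives math.frexp(0.25) -> (0.5, -1), so multiplier = round(0.5 * 2**31) = 2**30
    # and shift = -(-1) + 31 = 32; the pair represents
    # 0.25 == multiplier * 2**-shift, with shift inside the required [2, 62] range.
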

class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

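    # Illustrative: tgScatter returns [values_in_shape, input_shape] shaped
    # [[N, K, C], [N, W, C]], where W is drawn independently of K and is the
    # number of elements addressed along the middle dimension.
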
    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

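    # Illustrative: for shape [3, 4, 5] and two operands, the operand chosen
    # by bcast_idx might become [3, 1, 5] (one random dimension set to 1 so it
    # broadcasts) while the other keeps the full [3, 4, 5].
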
    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

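    # Illustrative: an NHWC IFM of [2, 16, 16, 8] with filter_hw = [3, 3] and
    # a randomly drawn ofm_depth of 4 yields filter_shape [4, 3, 3, 8] (OHWI)
    # and bias_shape [4].
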
    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList

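    # Worked example (illustrative): for shapeList = [[2, 8, 3]] * 4 and
    # axis = 1 this returns [[2, 8, 3], [2, 4, 3], [2, 2, 3], [2, 2, 3]]:
    # the original shape first, then splits whose axis lengths (4 + 2 + 2)
    # sum back to the original 8.
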

class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
        bigStride = 8
        strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
        bigDilation = 7
        dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list

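    # Note on sparsity (illustrative): with, say, 1300 stride/padding/dilation
    # combinations, sparsity starts at 1300 // 100 + 1 = 14 and is bumped to
    # 17, the next value coprime with 2, 3 and 5, so the kept samples cycle
    # through all the enumerated values rather than repeating a fixed pattern.
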
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
        bigStride = 8
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigDilation = 7
        dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list

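    # Illustrative output-shape check for the formula above: ifm_shape
    # [1, 8, 8, C], a 3x3 filter, stride (1, 1), dilation (1, 1) and padding
    # (0, 0) give oh = ow = (8 - 3 - 0 + 0) // 1 + 1 = 6.
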
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings)]))

        return arg_list

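    # Illustrative: for a rank-2 input and the default range [0, 1] this
    # yields 4**2 = 16 argument tuples; e.g. paddings ((0, 0), (1, 1)) is
    # named "pad0011".
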
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if inDtype == DType.UINT8 and dtype != DType.INT8:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue

            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue
                        if double_round and not scale32:
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

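    # Illustrative: for an INT8 input this walks every legal combination of
    # output dtype and (scale32, double_round, per_channel); the pair
    # scale32=False with double_round=True is always skipped, mirroring the
    # spec's ERROR_IF(!scale32 && double_round).
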
    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

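    # Note: getFactors only collects factors up to sqrt(val), e.g.
    # getFactors(24) == [1, 2, 3, 4]; agReshape below re-factors the
    # remaining element count after each pick, so larger divisors are
    # still reachable.
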
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

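    # Illustrative: an input of shape [2, 3, 4] has 24 elements, so a
    # generated newShape might be [4, 6], [2, 12] or [3, 2, 4]; roughly one
    # time in four a dimension is then overwritten with -1, leaving the
    # reshape to infer it.
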
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them.
                    # An output_dim of 1 would cause the offset to exceed the
                    # allowed range, so a minimum value of 2 is produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
                        output_dims[0] += 1
                    while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
                        output_dims[1] += 1

                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list

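    # Illustrative fixed-point check for the integer path above: with
    # ifm_shape[1] = 10 and output_dims[0] = 5, fp_stride_y = 2.0, so at
    # shift = 11 (unit = 2048) stride_y = round(2.0 * 2048) = 4096, well
    # below the 16 << 11 = 32768 limit, and no shift reduction is needed.
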
    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list


class TosaErrorIfArgGen:

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):

        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                output_list = []
        return input_list, output_list

class TosaErrorValidator:

    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        # Check ERROR_IF statements

        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)

            validator_name = val_result['error_name']
            error_result = val_result['error_result']
            error_reason = val_result['error_reason']

            if error_result:
                if error_name == validator_name:
                    serializer.setExpectedReturnCode(2, error_reason)
                else:
                    print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
                    return None  # Return None to delete test if wrong ERROR_IF is hit
            else:
                if error_name == validator_name:
                    print(f"No ERROR_IF hit for {error_name}")
                    return None

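    # Contract (as used below): every ev* validator returns a dict with keys
    # "error_name", "error_result", "error_reason" and "param_reqs"; when
    # called with check=True it inspects the supplied kwargs and sets
    # error_result when the ERROR_IF condition is present.
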
    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)

        # Find the unsupported input data types
        assert 'op' in kwargs
        op = kwargs['op']
        input_dtypes = op['types']
        wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))

        error_name = ErrorIf.WrongInputType
        param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
        error_result = False
        error_reason = "Input data type not supported for this operator"

        if check:
            input_dtype = kwargs['input_dtype']
            if input_dtype not in input_dtypes:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputType
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output data type not supported for this configuration of operator"

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            op = kwargs['op']

            if op['op'] == Op.RESIZE:
                mode = kwargs['mode']
                if (
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
                    (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
                    (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True
            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert 'op' in kwargs
        op = kwargs['op']
        rmin, rmax = op['rank']
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Set minimum incorrect rank to 3 to avoid index error
        if op['op'] == Op.RESIZE:
            incorrect_ranks = [3, 5]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs['input_shape']
            if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

1318 @staticmethod
1319 def evWrongInputList(check=False, **kwargs):
1320 error_name = ErrorIf.WrongInputList
1321 param_reqs = {"rank": None, "dtype": None, "shape": None}
1322 error_result = False
1323 error_reason = "Op input list does not match expected input"
1324
1325 if check:
1326 op = kwargs['op']
1327 input_list = kwargs['input_list']
1328 num_operands = kwargs['num_operands']
1329 if len(input_list) != num_operands:
1330 error_result = True
1331
1332 info_dict = {
1333 "error_name": error_name,
1334 "error_result": error_result,
1335 "error_reason": error_reason,
1336 "param_reqs": param_reqs
1337 }
1338 return info_dict
1339
1340 @staticmethod
1341 def evWrongOutputList(check=False, **kwargs):
1342 error_name = ErrorIf.WrongOutputList
1343 param_reqs = {"rank": None, "dtype": None, "shape": None}
1344 error_result = False
1345 error_reason = "Op output list does not match expected output"
1346
1347 if check:
1348 output_list = kwargs['output_list']
1349 # Note this will be incorrect if an operator returns more than one output
1350 if len(output_list) != 1:
1351 error_result = True
1352
1353 info_dict = {
1354 "error_name": error_name,
1355 "error_result": error_result,
1356 "error_reason": error_reason,
1357 "param_reqs": param_reqs
1358 }
1359 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001360
1361 @staticmethod
1362 def evMaxDimExceeded(check=False, **kwargs):
1363 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001364 param_reqs = {
1365 "rank": [4,4],
1366 "dtype": [DType.INT8],
1367 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1368 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001369 error_result = False
1370 error_reason = "At least one maximum dimension is larger than 16384"
1371
1372 if check:
1373 input_shape = kwargs['input_shape'].shape
1374 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1375 if ((input_shape[1] > 16384) or
1376 (input_shape[2] > 16384) or
1377 (output_shape[0] > 16384) or
1378 (output_shape[1] > 16384)):
1379 error_result = True
1380
1381 info_dict = {
1382 "error_name": error_name,
1383 "error_result": error_result,
1384 "error_reason": error_reason,
1385 "param_reqs": param_reqs
1386 }
1387 return info_dict
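
    # Example (illustrative): with the INT8 RESIZE shapes suggested in param_reqs
    # above, an input of [1, 16584, 5, 1] trips the check on the height dimension
    # (16584 > 16384) and [1, 2, 16499, 4] trips it on the width dimension.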

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert 'op' in kwargs
        op = kwargs['op']
        rmin, rmax = op['rank']
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs['input_shape'].shape
            output_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)

            if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert 'op' in kwargs
        op = kwargs['op']
        rmin, rmax = op['rank']
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs['input_shape'].shape
            output_shape = kwargs['result_tensor'].shape  # Note this is just (N, OH, OW, C)
            if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evStrideSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Stride value smaller than or equal to zero"

        if check:
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
                stride = kwargs['stride']  # Work around wrong input/output type tests
            elif output_dtype == DType.FLOAT:
                stride = kwargs['stride_fp']
            elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                stride = kwargs['stride_fp']  # Work around wrong input/output type tests
            else:
                stride = kwargs['stride']

            if min(stride) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evStrideLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.StrideLargerEqualMax
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Stride value larger than or equal to maximum value"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            stride = kwargs['stride']
            if input_dtype in [DType.INT8, DType.INT16]:
                if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
                    error_result = True
                elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
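
    # Worked example (illustrative): stride is a fixed-point value scaled by
    # (1 << shift), so for shift == 6 the limit is 16 << 6 == 1024 and a stride
    # of [1024, 4] is rejected while [1023, 4] is accepted.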


    @staticmethod
    def evStrideLargerDimension(check=False, **kwargs):
        error_name = ErrorIf.StrideLargerDimension
        param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
        error_result = False
        error_reason = "Stride value larger than H/W dimension"

        if check:
            shape = kwargs['input_shape'].shape
            input_dtype = kwargs['input_dtype']
            stride = kwargs['stride_fp']

            if input_dtype == DType.FLOAT and ((stride[0] > shape[1]) or (stride[1] > shape[2])):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict


    @staticmethod
    def evOffsetSmallerEqualMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerEqualMin
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Offset value smaller than or equal to minimum value"

        if check:
            shift = kwargs['shift']
            output_dtype = kwargs['output_dtype']
            if output_dtype == DType.FLOAT:
                offset = kwargs['offset_fp']
            else:
                offset = kwargs['offset']

            if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
                error_result = True
            elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            shift = kwargs['shift']
            output_dtype = kwargs['output_dtype']
            if output_dtype == DType.FLOAT:
                offset = kwargs['offset_fp']
            else:
                offset = kwargs['offset']

            if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
                error_result = True
            elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evShiftNotZero(check=False, **kwargs):
        error_name = ErrorIf.ShiftNotZero
        param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
        error_result = False
        error_reason = "Shift value must be zero for float input"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict


    @staticmethod
    def evShiftSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.ShiftSmallerOne
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Shift value smaller than one"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            output_dtype = kwargs['output_dtype']
            if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evShiftLargerEleven(check=False, **kwargs):
        error_name = ErrorIf.ShiftLargerEleven
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Shift value larger than eleven"

        if check:
            shift = kwargs['shift']
            if shift > 11:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict


    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input rank does not match output rank"

        if check:
            input1_shape = kwargs['input1'].shape
            input2_shape = kwargs['input2'].shape
            output_shape = kwargs['result_tensor'].shape
            if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        error_name = ErrorIf.InputZeroPointNotZero
        param_reqs = {
            "rank": None,
            "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
            "shape": None
        }
        error_result = False
        error_reason = "Input DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs['input_dtype']
            # Layout: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            qinfo = kwargs['qinfo'].ints
            input_zero_point = qinfo[0][1]
            if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
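
    # Example (illustrative, per the qinfo layout noted in the check above): a
    # UnaryQuantInfo built with input_zp=5 and output_zp=0 yields
    #   qinfo.ints[0][1] == 5   # input_zp
    #   qinfo.ints[1][1] == 0   # output_zp
    # so an INT16 input with input_zp == 5 sets error_result.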


    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {
            "rank": None,
            "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
            "shape": None
        }
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            output_dtype = kwargs['output_dtype']
            # Layout: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
            qinfo = kwargs['qinfo'].ints
            output_zero_point = qinfo[1][1]
            if output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs['axis']
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict


    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs['axis']
            shape = kwargs['input_shape']
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict


    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs['axis']
            shape = kwargs['output_shape']
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict
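
    # Example (illustrative): for a reduction over axis 1 of a [2, 3, 4] input,
    # the three axis validators above require axis >= 0, axis <= rank and
    # output_shape[axis] == 1; an output shape of [2, 3, 4] (axis 1 left at 3)
    # sets error_result in evShapeOfAxisNotOne, while [2, 1, 4] passes.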


class TosaInvalidValidator:

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        stride = args[1]
        stride_fp = args[4]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / invalid input datatype.
            # Only INT8->INT32, INT16->INT48 and FLOAT->FLOAT are valid,
            # so reject any pairing outside that set.
            return (
                not (
                    (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
                    (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
                    (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
                ) or
                (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / invalid input datatype
            return (
                (input_dtype != output_dtype) or
                (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True
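
    # Sketch of the dtype pairs treated as valid by the BILINEAR branch above:
    #   INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT
    # Any other (input, output) pairing, or an input dtype outside
    # {INT8, INT16, FLOAT}, marks the argument combination invalid.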

    @staticmethod
    def ivBadStride(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False


    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        input = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            kernel = args[2]

        if opName.startswith('conv2d'):
            h = (
                input[1]
                - filter[1]
                - (filter[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[2]
                - (filter[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            h = (
                input[1]
                - filter[0]
                - (filter[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[1]
                - (filter[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False
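
    # Worked example (illustrative): conv2d with input H of 8, filter H of 3,
    # dilation 2, padding (0, 0) and stride 2 gives
    #   h = (8 - 3 - (3 - 1) * (2 - 1) + 0 + 0) // 2 + 1 == 2
    # which is valid, while an input H of 4 gives h == 0 and the parameter
    # combination is skipped as invalid.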

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False


class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))

    def resetRNG(self, seed=None):
        if seed is None:
            seed = self.random_seed + 1
        self.rng = np.random.default_rng(seed)

    def getRandTensor(self, shape, dtype):
        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
        elif dtype == DType.UINT8:
            return np.int32(self.rng.integers(low=0, high=256, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(self.rng.random(size=shape))
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):

        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # Where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function

    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # build_placeholder returns an int, ABS/other ops do not
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
            return result_tens
        elif op['op'] == Op.IDENTITY:
            self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error_if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            result_tensor=result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

        # Invalidate Input/Output list for error_if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list)
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
        return result_tens
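
    # Example (illustrative): two INT8 operands with shift == 0 still produce an
    # INT32 result tensor here, matching the TOSA rule that integer MUL widens
    # its result to INT32; FLOAT inputs keep the FLOAT dtype from the shaper.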

    def build_table(self, op, a):
        # Constant size depending on type, random values
        if a.dtype == DType.INT16:
            table_dtype = DType.INT16
            table_arr = self.getRandTensor([513], table_dtype)
        else:
            assert a.dtype == DType.INT8
            table_dtype = DType.INT8
            table_arr = self.getRandTensor([256], table_dtype)

        table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
        result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
        self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)

        return result_tens
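
    # Note on table sizes (assumption based on the TOSA TABLE definition): the
    # INT16 variant uses 513 entries (512 segments plus a final endpoint for
    # interpolation) while the INT8 variant needs one entry per input value, 256.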

    def build_select(self, op, cond, a, b):
        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad)

        self.ser.addOperator(op['op'], [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 6
        result_tens = OutputShaper.conv3dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
        result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error_if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens

    def build_clamp(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

        attr = ts.TosaSerializerAttribute()
        v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min(v), max(v))
        else:
            attr.ClampAttribute(min(v), max(v), 0, 0)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    def build_leaky_relu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name])
        return result_tens

    def build_sigmoid(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
        self.ser.addOperator(op['op'], [a.name], [result_tens.name])
        return result_tens

    def build_tanh(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
        self.ser.addOperator(op['op'], [a.name], [result_tens.name])
        return result_tens

    def build_concat(self, op, *a):
        assert isinstance(a[-1], int)

        # To store a variable-length list of input tensors we need to store the
        # axis along with it
        axis = a[-1]
        a = a[:-1]

        result_tens = OutputShaper.concatOp(self.ser, axis, *a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        input_tensor_names = []
        for tensor in a:
            input_tensor_names.append(tensor.name)

        self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
        return result_tens

    def build_pad(self, op, a, padding, qinfo):
        result_tens = OutputShaper.padOp(self.ser, a, padding)

        # Need to turn the padding array into a TOSA tensor here.
        # This is one of the few tensor operands that does not get
        # randomly generated
        padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)

        self.ser.addOperator(
            op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_reshape(self, op, a, newShape):
        result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    def build_reverse(self, op, a, axis):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    def build_transpose(self, op, a, perms):
        result_tens = OutputShaper.transposeOp(self.ser, a, perms)

        perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

        self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
        return result_tens

    def build_slice(self, op, a, begin, size):
        result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(begin, size)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    def build_tile(self, op, a, multiples):
        result_tens = OutputShaper.tileOp(self.ser, a, multiples)

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
        return result_tens

    def build_gather(self, op, values):

        # Create a new indices tensor here with data that doesn't exceed
        # the dimensions of the values tensor
        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)

        self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])

        return result_tens

    def build_scatter(self, op, values_in, input):

        # Create a new indices tensor here with data that doesn't exceed
        # the dimensions of the values_in tensor
        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)

        self.ser.addOperator(
            op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
        )

        return result_tens
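
    # Sketch of the index semantics assumed above: values_in is (N, K, C) and
    # input is (N, W, C), with every index drawn from [0, K); SCATTER writes the
    # W rows of input into the K rows of values_in per batch, and GATHER reads
    # W rows out of the K values rows in the same way.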

    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        validator_fcns,
        error_name=None,
    ):
        result_tens = OutputShaper.resizeOp(
            self.ser,
            self.rng,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
            error_name
        )

        # Invalidate Input/Output list for error_if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            mode=mode,
            shift=shift,
            input_dtype=input_dtype,
            output_dtype=output_dtype,
            input_shape=input,
            output_shape=output_dims,
            offset=offset,
            offset_fp=offset_fp,
            stride=stride,
            stride_fp=stride_fp,
            input_list=input_list,
            output_list=output_list,
            result_tensor=result_tens,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens

    def build_identityn(self, op, val, val2):
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
        result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens

    def build_const(self, op, val):
        self.ser.addOutputTensor(val)
        return val

    # Type Conversion
    def build_cast(self, op, val, out_dtype):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
        self.ser.addOperator(op['op'], [val.name], [result_tens.name])
        return result_tens

    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        in_type_width = self.typeWidth(val.dtype)
        out_type_width = self.typeWidth(out_dtype)

        if val.dtype == DType.INT8:
            input_zp = self.randInt(-128, 128)
            in_type_width = in_type_width + 1
        elif val.dtype == DType.UINT8:
            input_zp = self.randInt(0, 256)
            in_type_width = in_type_width + 1
        else:
            input_zp = 0

        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 128)
            out_type_width = out_type_width + 1
        elif out_dtype == DType.UINT8:
            output_zp = self.randInt(0, 256)
            out_type_width = out_type_width + 1
        else:
            output_zp = 0

        # Calculate scale based on:
        # scale = a * (2^output_width) / (2^input_width)

        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

        multiplier_arr = np.int32(np.zeros(shape=[nc]))
        shift_arr = np.int32(np.zeros(shape=[nc]))

        for i in range(nc):
            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                scale_arr[i], scale32
            )

        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            multiplier_arr,
            shift_arr,
            scale32,
            double_round,
            per_channel,
        )

        self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
        return result_tens
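
    # Worked example (illustrative): for an INT8 -> INT8 rescale both widths are
    # counted as 9 bits (zero point in use), so scale_arr = a * 2^9 / 2^9 = a;
    # computeMultiplierAndShift then splits e.g. a = 0.75 into
    # multiplier = 0.75 * 2^31 and shift = 31 in scale32 mode, assuming the
    # usual normalisation of the multiplier into [2^30, 2^31).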

    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens

    def build_cond_if_binary(self, op, a, b, cond):
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens

    def build_while_loop(self, op, a, iter_val):
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op['op'],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        # COND block (input: iter, output: cond_tens)
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
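        # Each pass computes acc += a and iter -= 1, and COND keeps the loop
        # running while iter > 0, so the final accumulator equals
        # a * iter_val (elementwise).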
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out

    def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
        # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
        default_test_rank_range = range(1, 5)
        if not shapeFilter:
            shapeFilter = [None]

        # Calculate the filters based on what is requested and what the operator allows
        rmin, rmax = op["rank"]
        if rankFilter is not None:
            cleanRankFilter = []
            # Ensure rankFilter values are allowed by operator
            for rank in rankFilter:
                if rmin <= rank <= rmax:
                    cleanRankFilter.append(rank)
        elif rankFilter is None and shapeFilter[0] is None:
            # Ensure default behaviour is bounded by default range or by operator,
            # whichever is the smaller range of ranks.
            opRankRange = range(rmin, rmax + 1)
            cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
        else:
            cleanRankFilter = range(rmin, rmax + 1)

        dtypes = op["types"]
        if dtypeFilter is not None:
            cleanDtypeFilter = []
            # Create list of operator dtypes filtered by requested dtypes
            for dtype in dtypes:
                if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
                    cleanDtypeFilter.append(dtype)
        else:
            cleanDtypeFilter = dtypes

        if testType == 'positive':
            filterDict = {
                'shapeFilter': shapeFilter,
                'rankFilter': cleanRankFilter,
                'dtypeFilter': cleanDtypeFilter
            }
            return filterDict
        elif testType == 'negative':
            validator_info = validator(check=False, op=op)
            error_arguments = validator_info['param_reqs']

            # Set parameters as required
            if error_arguments['rank'] is not None:
                rankFilter = error_arguments['rank']
            else:
                rankFilter = cleanRankFilter

            if error_arguments['dtype'] is not None:
                dtypeFilter = error_arguments['dtype']
            else:
                dtypeFilter = cleanDtypeFilter

            if error_arguments['shape'] is not None:
                shapeFilter = error_arguments['shape']
            else:
                shapeFilter = shapeFilter[:2]  # Reduce number of shapes to keep test numbers small

            filterDict = {
                'shapeFilter': shapeFilter,
                'rankFilter': rankFilter,
                'dtypeFilter': dtypeFilter
            }
            return filterDict

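    # Illustrative shape of the dict returned by create_filter_lists() for a
    # positive run (example values only):
    #   {'shapeFilter': [None],
    #    'rankFilter': range(1, 5),
    #    'dtypeFilter': [DType.INT8, DType.FLOAT]}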

    def genOpTestList(
        self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
    ):

        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))

        # Initialize a new random number generator
        self.rng = np.random.default_rng(self.random_seed)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

        # Test list consists of a tuple of:
        # (opName, testNameStr, dtype, errorName, shapeList, argumentsList)
        testList = []
        if testType == 'negative' and "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
        else:
            error_if_validators = [None]

        for validator in error_if_validators:
            if validator is not None:
                error_name = validator(check=False, op=op)['error_name']
            else:
                error_name = None

            filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
            cleanRankFilter = filterDict['rankFilter']
            cleanDtypeFilter = filterDict['dtypeFilter']
            cleanShapeFilter = filterDict['shapeFilter']

            for r in cleanRankFilter:
                if opName.startswith("conv3d"):
                    assert r == 5, "conv3d test must have input rank == 5"
                for t in cleanDtypeFilter:
                    for shape in cleanShapeFilter:
                        # Filter out by rank
                        if shape is not None and len(shape) != r:
                            continue
                        self.setTargetShape(shape)
                        shapeList = tgen_fcn(self, op, r, error_name)

                        shapeStr = self.shapeStr(shapeList[0])
                        typeStr = self.typeStr(t)

                        # Argument lists consist of tuples of the (str, []) string representation and the build function argument list
                        argList = []
                        if agen_fcn:
                            argList = agen_fcn(self, opName, shapeList, t, error_name)
                        else:
                            argList = [("", [])]

                        for argStr, args in argList:
                            if testType == 'positive':
                                if argStr:
                                    testStr = "{}_{}_{}_{}".format(
                                        opName, shapeStr, typeStr, argStr
                                    )
                                else:
                                    testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                            elif testType == 'negative':
                                if argStr:
                                    testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                        opName, error_name, shapeStr, typeStr, argStr
                                    )
                                else:
                                    testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)

                            testList.append((opName, testStr, t, error_name, shapeList, args))
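                            # The resulting test names look roughly like
                            # "add_13x21x3_i32" for positive tests and
                            # "add_ERRORIF_RankMismatch_13x21x3_i32" for
                            # negative ones (illustrative values only).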

        if testType == 'positive':
            # Remove tests which are expected to fail but don't correlate to an ERROR_IF statement
            if "invalid_test_validators" in op:
                invalid_test_validators = op["invalid_test_validators"]
                clean_testList = []
                for test in testList:
                    # Reset the flag per test, so any single validator can
                    # mark the test for removal
                    remove_test = False
                    for validator_fcn in invalid_test_validators:
                        if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
                            remove_test = True
                    if not remove_test:
                        clean_testList.append(test)
                testList = clean_testList

        return testList

    def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        if "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
        else:
            error_if_validators = None

        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        elif op["op"] == Op.CONCAT:
            dtypeList = [dtype_or_dtypeList] * len(shapeList)
        else:
            dtypeList = [dtype_or_dtypeList] * num_operands

        if op["op"] != Op.CONCAT:
            assert (
                len(shapeList) == num_operands
            ), "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
            assert (
                len(dtypeList) == num_operands
            ), "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )

        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

        if qgen is not None:
            qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
        else:
            qinfo = None

        try:
            if error_if_validators is None:
                if qinfo is not None:
                    resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
                else:
                    resultName = build_fcn(self, op, *tens, *testArgs)
            else:
                if qinfo is not None:
                    resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
                else:
                    resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
        except TypeError as e:
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        if resultName is None:
            print("Invalid ERROR_IF tests created")

        # Save the serialized test
        self.serialize("test")

    def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the result wraps due to the limited number of bits available to
            # store the answer
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = (op["op"] == Op.ADD)
            a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much the values exceed the maximum/minimum
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap the saturation values and negate them, as we need to
                # perform the opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

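            # Worked example (illustrative): for a = 2_000_000_000 and
            # b = 1_000_000_000, res = 3_000_000_000 exceeds max_i32
            # (2_147_483_647) by 852_516_353, so b is reduced by that amount
            # and a + b then lands exactly on max_i32 instead of wrapping.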
            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert dim == 1, "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert dim == 1, "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            # Force value of operand[1] to be within [0, num_bits]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.INTDIV and error_name is None:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
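            #    (e.g. dividend = -2147483648 with divisor = -1 overflows
            #    int32, so any draw hitting either case is rejected below)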
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003112 while True:
3113 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
3114 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
3115
3116 if (divisor_arr == 0).any():
3117 continue
3118
Kevin Cheng47315e12021-05-13 17:41:28 -07003119 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003120 continue
3121
3122 break
3123
3124 placeholders.append(
3125 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
3126 )
3127 placeholders.append(
3128 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
3129 )
3130
3131 tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply result fits in the int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64
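                    # The shift applies round-half-up scaling, e.g. with
                    # shift = 2: a*b = 35 gives (35 + 2) >> 2 = 9, i.e. 35/4
                    # rounded to the nearest integer.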

                    if (result_arr > -(2 ** 31)).all() and (
                        result_arr <= ((2 ** 31) - 1)
                    ).all():
                        break

                    i += 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        elif op["op"] == Op.CONCAT:
            count = len(shapeList) - self.args.num_const_inputs_concat
            if count < 1:
                count = 1
            if self.args.num_const_inputs_concat == 0:
                count = len(shapeList)

            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
            )
            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
        else:
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    def createDynamicOpLists(self):

        # Dynamically create op lists for convolutions with a list of kernel sizes
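        # e.g. "conv2d_TEMPLATE" expands into concrete entries such as
        # "conv2d_1x1" and "conv2d_3x3"; the templates themselves are removed
        # again at the end of this method.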
        KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]

        for k in KERNELS_2D:
            testName = "conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

            testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                "depthwise_conv2d_TEMPLATE"
            ].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

            testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                "transpose_conv2d_TEMPLATE"
            ].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

        KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
        for k in KERNELS_3D:
            testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

        # Delete any templates after having created any dynamic ops
        # This is a two-pass operation because it's bad practice to delete
        # keys from dictionaries while iterating
        keyList = []
        for k in self.TOSA_OP_LIST:
            try:
                if self.TOSA_OP_LIST[k]["template"]:
                    keyList.append(k)
                    continue
            except KeyError:
                pass

        for k in keyList:
            del self.TOSA_OP_LIST[k]

    def initOpListDefaults(self):
        """Fill in default fields for ops if they aren't already specified.
        Look for missing required fields (datastructure linting)."""
        for op in self.TOSA_OP_LIST:

            # Required fields
            try:
                pl, c = self.TOSA_OP_LIST[op]["operands"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
                )

            try:
                fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                        op
                    )
                )

            try:
                types = self.TOSA_OP_LIST[op]["types"]
            except KeyError:
                raise Exception(
                    "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
                )

            try:
                opcode = self.TOSA_OP_LIST[op]["op"]
            except KeyError:
                raise Exception(
                    "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
                )

            # Put in default rank range, if missing
            try:
                rank = self.TOSA_OP_LIST[op]["rank"]
            except KeyError:
                self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE

    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    #   if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested
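    # A minimal illustrative entry mirroring the fields described above:
    #   "abs": {"op": Op.ABS, "operands": (1, 0),
    #           "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
    #           "types": TYPE_FI32}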
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)

    TOSA_OP_LIST = {
        # Tensor operators
        "argmax": {
            "op": Op.ARGMAX,
            "operands": (1, 0),
            "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_NARROW_INT_FP,
        },
        "avg_pool2d": {
            "op": Op.AVG_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_NARROW_INT_FP,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        },
        # Templated operator. Filled in by createDynamicOpLists
        "conv2d_TEMPLATE": {
            "op": Op.CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
            "template": True,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "conv3d_TEMPLATE": {
            "op": Op.CONV3D,
            "operands": (1, 2),
            "rank": (5, 5),
            "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "template": True,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "depthwise_conv2d_TEMPLATE": {
            "op": Op.DEPTHWISE_CONV2D,
            "operands": (1, 2),
            "filter": [1, 1],
            "rank": (4, 4),
            "build_fcn": (
                build_depthwise_conv2d,
                TosaTensorGen.tgDepthwiseConv2D,
                TosaArgGen.agConv,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
            "template": True,
        },
        "fully_connected": {
            "op": Op.FULLY_CONNECTED,
            "operands": (1, 2),
            "rank": (2, 2),
            "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
        },
        "matmul": {
            "op": Op.MATMUL,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
            "qgen": TosaQuantGen.qgMatmul,
            "types": TYPE_NARROW_INT_FP,
        },
        "max_pool2d": {
            "op": Op.MAX_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "types": TYPE_NARROW_INT_FP,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        },
        # Templated operator. Filled in by createDynamicOpLists
        "transpose_conv2d_TEMPLATE": {
            "op": Op.TRANSPOSE_CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (
                build_transpose_conv2d,
                TosaTensorGen.tgTransposeConv2D,
                TosaArgGen.agTransposeConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
            "template": True,
        },
        # Activation functions
        "clamp": {
            "op": Op.CLAMP,
            "operands": (1, 0),
            "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
            "types": TYPE_NARROW_INT_FP,
        },
        "sigmoid": {
            "op": Op.SIGMOID,
            "operands": (1, 0),
            "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "tanh": {
            "op": Op.TANH,
            "operands": (1, 0),
            "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Binary Operators
        "add": {
            "op": Op.ADD,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "arithmetic_right_shift": {
            "op": Op.ARITHMETIC_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_arithmetic_right_shift,
                TosaTensorGen.tgBroadcastFuzz,
                TosaArgGen.agArithmeticRightShift,
            ),
            "types": TYPE_INT,
        },
        "bitwise_and": {
            "op": Op.BITWISE_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "bitwise_or": {
            "op": Op.BITWISE_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "bitwise_xor": {
            "op": Op.BITWISE_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "intdiv": {
            "op": Op.INTDIV,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": [DType.INT32],
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "logical_and": {
            "op": Op.LOGICAL_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "logical_left_shift": {
            "op": Op.LOGICAL_LEFT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "logical_right_shift": {
            "op": Op.LOGICAL_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "logical_or": {
            "op": Op.LOGICAL_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "logical_xor": {
            "op": Op.LOGICAL_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "maximum": {
            "op": Op.MAXIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "minimum": {
            "op": Op.MINIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "mul": {
            "op": Op.MUL,
            "operands": (2, 0),
            "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
            "types": TYPE_INT_FP,
        },
        "pow": {
            "op": Op.POW,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "sub": {
            "op": Op.SUB,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "table": {
            "op": Op.TABLE,
            # Use the automatic generation functions to create the input array
            # but create the table tensor in the build function, as it may be
            # a different type from the input
            "operands": (1, 0),
            "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
            "types": [DType.INT8, DType.INT16],
        },
        # Elementwise Unary operators
        "abs": {
            "op": Op.ABS,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FI32,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "bitwise_not": {
            "op": Op.BITWISE_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "ceil": {
            "op": Op.CEIL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "clz": {
            "op": Op.CLZ,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": [DType.INT32],
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "exp": {
            "op": Op.EXP,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "floor": {
            "op": Op.FLOOR,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "log": {
            "op": Op.LOG,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "logical_not": {
            "op": Op.LOGICAL_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_BOOL,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "negate": {
            "op": Op.NEGATE,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_INT_FP,
            "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
                                    TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "reciprocal": {
            "op": Op.RECIPROCAL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "rsqrt": {
            "op": Op.RSQRT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        # Elementwise Ternary operators
        "select": {
            "op": Op.SELECT,
            "operands": (3, 0),
            "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FIB,
        },
        # Comparison operators
        "equal": {
            "op": Op.EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater_equal": {
            "op": Op.GREATER_EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater": {
            "op": Op.GREATER,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        # Reduction operators
        "reduce_all": {
            "op": Op.REDUCE_ALL,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
            "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
                                    TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "reduce_any": {
            "op": Op.REDUCE_ANY,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
            "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
                                    TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "reduce_max": {
            "op": Op.REDUCE_MAX,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
            "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
                                    TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
        "reduce_min": {
            # Fixed: this entry previously pointed at Op.REDUCE_MAX
            "op": Op.REDUCE_MIN,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
            "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
                                    TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
                                    TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList),
        },
Jared Smolens573ecd42021-03-04 15:24:10 -08003734 "reduce_product": {
3735 "op": Op.REDUCE_PRODUCT,
3736 "operands": (1, 0),
3737 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3738 "types": TYPE_FP,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003739 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3740 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3741 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003742 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 "reduce_sum": {
3744 "op": Op.REDUCE_SUM,
3745 "operands": (1, 0),
3746 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3747 "types": TYPE_FI32,
Matthew Haddond6ce7252021-09-29 15:35:44 +01003748 "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
3749 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
3750 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003751 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003752 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003753 "concat": {
3754 "op": Op.CONCAT,
3755 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01003756 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003757 "types": TYPE_FIB,
3758 },
3759 "pad": {
3760 "op": Op.PAD,
3761 "operands": (1, 0),
3762 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
3763 "qgen": TosaQuantGen.qgPad,
3764 "types": TYPE_FIB,
3765 },
3766 "reshape": {
3767 "op": Op.RESHAPE,
3768 "operands": (1, 0),
3769 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
3770 "types": TYPE_FIB,
3771 },
3772 "reverse": {
3773 "op": Op.REVERSE,
3774 "operands": (1, 0),
3775 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3776 "types": TYPE_FIB,
3777 },
3778 "slice": {
3779 "op": Op.SLICE,
3780 "operands": (1, 0),
3781 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
3782 "types": TYPE_FIB,
3783 },
3784 "tile": {
3785 "op": Op.TILE,
3786 "operands": (1, 0),
3787 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
3788 "types": TYPE_FIB,
3789 },
3790 "transpose": {
3791 "op": Op.TRANSPOSE,
3792 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003793 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003794 "build_fcn": (
3795 build_transpose,
3796 TosaTensorGen.tgBasic,
3797 TosaArgGen.agTranspose,
3798 ),
3799 "types": TYPE_FIB,
3800 },
        # Data nodes
        "const": {
            "op": Op.CONST,
            "operands": (0, 1),
            "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        "identity": {
            "op": Op.IDENTITY,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        # Scatter/Gather
        "gather": {
            "op": Op.GATHER,
            # Only specify 'values' tensor here. 'indices' is generated in op building stage
            "operands": (1, 0),
            "rank": (3, 3),
            "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT_FP,
        },
        "scatter": {
            "op": Op.SCATTER,
            # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
            "types": TYPE_INT_FP,
        },
        # Image operations
        "resize": {
            "op": Op.RESIZE,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
            "types": [DType.INT8, DType.INT16, DType.FLOAT],
            "invalid_test_validators": (
                TosaInvalidValidator.ivWrongDataTypeOrModeResize,
                TosaInvalidValidator.ivBadStride,
            ),
            "error_if_validators": (
                TosaErrorValidator.evMaxDimExceeded,
                TosaErrorValidator.evStrideSmallerEqualZero,
                TosaErrorValidator.evStrideLargerDimension,
                TosaErrorValidator.evStrideLargerEqualMax,
                TosaErrorValidator.evOffsetSmallerEqualMin,
                TosaErrorValidator.evOffsetLargerEqualMax,
                TosaErrorValidator.evShiftNotZero,
                TosaErrorValidator.evShiftSmallerOne,
                TosaErrorValidator.evShiftLargerEleven,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evBatchMismatch,
                TosaErrorValidator.evChannelMismatch,
            ),
        },
        # Type conversion
        "cast": {
            "op": Op.CAST,
            "operands": (1, 0),
            "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
            "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
        },
        "rescale": {
            "op": Op.RESCALE,
            "operands": (1, 0),
            "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
            "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
        },
        # Custom
        # Not implemented.
        # Control flow operators
        # Two variants of cond_if: one generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and the other either adds or
        # subtracts two tensors (two inputs to the basic blocks, one output)
        "cond_if_const": {
            "op": Op.COND_IF,
            "operands": (0, 2),
            "build_fcn": (
                build_cond_if_const,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": [DType.BOOL],
        },
        "cond_if_binary": {
            "op": Op.COND_IF,
            "operands": (2, 0),
            "build_fcn": (
                build_cond_if_binary,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": TYPE_FI32,
        },
        # while_loop
        "while_loop": {
            "op": Op.WHILE_LOOP,
            "operands": (0, 1),
            "build_fcn": (
                build_while_loop,
                TosaTensorGen.tgBasic,
                TosaArgGen.agWhileLoop,
            ),
            "types": [DType.INT32],
        },
    }

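    # Illustrative note (not part of the op list): each entry above appears to
    # follow a common schema. A hypothetical entry would look like the sketch
    # below; the key names are the ones actually used in this table, while the
    # "example_op" values are placeholders, not real operators.
    #
    #   "example_op": {
    #       "op": Op.EXAMPLE,                  # serializer opcode (placeholder)
    #       "operands": (1, 0),                # (num inputs, num constants)
    #       "rank": (1, 4),                    # optional min/max tensor rank
    #       "build_fcn": (build_fn, tensor_gen, arg_gen),
    #       "qgen": TosaQuantGen.qgUnary,      # optional quantization info generator
    #       "types": TYPE_FIB,                 # dtypes to sweep
    #       "invalid_test_validators": (...),  # optional, filters invalid combinations
    #       "error_if_validators": (...),      # optional, ERROR_IF test hooks
    #   },
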

class OutputShaper:
    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, rng, a, b, error_name=None):
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1 and error_name is None:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

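    # Worked example (illustrative comment only): with inputs of shape [1, 3]
    # and [2, 3], the loop above broadcasts the size-1 dimension of `a`,
    # producing an output shape of [2, 3]:
    #
    #   i = 0: a.shape[0] == 1  -> take b.shape[0] == 2
    #   i = 1: a.shape[1] == 3  -> take a.shape[1] == 3
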
    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, rng, a, error_name=None):
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(a.shape, outputDType)

    @staticmethod
    def selectOp(ser, cond, a, b):
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryComparisonOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast
        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        # Force the output type to bool
        return ser.addOutput(shape, DType.BOOL)

    @staticmethod
    def reduceOp(ser, rng, a, axis, error_name=None):
        shape = a.shape.copy()
        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
            shape[axis] = 1
        if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
            shape[axis] = rng.integers(2, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

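    # Worked example (illustrative): reducing a [2, 5, 3] input over axis=1 in
    # the normal (error_name is None) path sets shape[1] = 1, so the expected
    # output shape is [2, 1, 3]; the dtype is unchanged unless WrongOutputType
    # is being injected.
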
    @staticmethod
    def argmaxOp(ser, a, axis):
        shape = a.shape.copy()
        del shape[axis]
        return ser.addOutput(shape, DType.INT32)

    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NHWC
        # Filter: OHWI
        # OFM: NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

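    # Worked example (illustrative) of the height computation above, assuming
    # ifm NHWC = [1, 16, 16, 8], filter OHWI = [4, 3, 3, 8], strides = [2, 2],
    # padding = [1, 1, 1, 1], dilations = [1, 1]:
    #
    #   h = (16 - 3 - (3 - 1) * (1 - 1) + 1 + 1) // 2 + 1 = 15 // 2 + 1 = 8
    #
    # giving ofm_shape = [1, 8, 8, 4] (width is computed the same way).
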
    @staticmethod
    def conv3dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NDHWC
        # Filter: ODHWI
        # OFM: NDHWC

        d = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        h = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        w = (
            ifm.shape[3]
            - filter.shape[3]
            - (filter.shape[3] - 1) * (dilations[2] - 1)
            + padding[4]
            + padding[5]
        ) // strides[2] + 1

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

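    # Illustrative note: this is the conv2dOp arithmetic extended with a depth
    # dimension. For example, with an input depth of 8, filter depth 3,
    # stride 1, dilation 1 and no padding:
    #
    #   d = (8 - 3 - 0 + 0 + 0) // 1 + 1 = 6
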
    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        # IFM: NHWC
        # Filter: HWCM
        # OFM: NHW C*M
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

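    # Worked example (illustrative): for an HWCM filter of shape [3, 3, 8, 2]
    # the output channel count is C * M = 8 * 2 = 16, so an NHWC input of
    # [1, 16, 16, 8] with stride 1, dilation 1 and padding [1, 1, 1, 1] yields
    # h = (16 - 3 - 0 + 1 + 1) // 1 + 1 = 16 and ofm_shape = [1, 16, 16, 16].
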
    @staticmethod
    def pool2dOp(ser, ifm, kernel, stride, pad):
        # input: NHWC
        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
        return ser.addOutput(ofm_shape, ifm.dtype)

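    # Worked example (illustrative): the expression above is equivalent to
    # floor((H + pad_top + pad_bottom - kernel) / stride) + 1. For a
    # [1, 8, 8, 4] input with kernel = [2, 2], stride = [2, 2] and
    # pad = [0, 0, 0, 0]:
    #
    #   h = (8 + 0 + 0 + 2 - 2) // 2 = 4
    #
    # so ofm_shape = [1, 4, 4, 4].
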
    @staticmethod
    def fullyConnectedOp(ser, input, filter):
        # input: N, IC
        # filter: OC, IC
        # output: N, OC

        output_shape = [input.shape[0], filter.shape[0]]

        if input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def matmulOp(ser, a, b):
        # a: N, H, C
        # b: N, C, W
        # out: N, H, W

        output_shape = [a.shape[0], a.shape[1], b.shape[2]]

        if a.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif a.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif a.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

        return ser.addOutput(output_shape, out_dtype)

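    # Worked example (illustrative): a of shape [2, 3, 4] (N, H, C) multiplied
    # by b of shape [2, 4, 5] (N, C, W) gives output_shape = [2, 3, 5], with
    # INT8 inputs widening to an INT32 output.
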
    @staticmethod
    def concatOp(ser, axis, *a):
        input1 = a[0]
        remaining_inputs = a[1:]

        output_shape = input1.shape.copy()

        output_shape[axis] = input1.shape[axis]

        for tensor in remaining_inputs:
            output_shape[axis] += tensor.shape[axis]

        return ser.addOutput(output_shape, input1.dtype)

    @staticmethod
    def padOp(ser, a, padding):

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

        return ser.addOutput(output_shape, a.dtype)

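    # Worked example (illustrative): for a.shape = [3, 4] and
    # padding = [[1, 1], [2, 2]] (before/after pairs per dimension), the
    # output shape is [1 + 1 + 3, 2 + 2 + 4] = [5, 8].
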
    @staticmethod
    def reshapeOp(ser, a, shape):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        return ser.addOutput(output_shape, a.dtype)

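    # Worked example (illustrative): reshaping a [2, 3, 4] input (24 elements)
    # to shape = [4, -1] computes totalOutputElements = 4, so the -1 is filled
    # in as 24 // 4 = 6, giving output_shape = [4, 6].
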
    @staticmethod
    def sliceOp(ser, a, begin, size):

        output_shape = size.copy()
        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def tileOp(ser, a, multiples):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def transposeOp(ser, a, perms):
        output_shape = a.shape.copy()
        assert len(perms) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[perms[i]]

        return ser.addOutput(output_shape, a.dtype)

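    # Worked example (illustrative): for a.shape = [2, 3, 4] and
    # perms = [2, 0, 1], output_shape[i] = a.shape[perms[i]] gives [4, 2, 3].
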
4240 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08004241 def gatherOp(ser, values, indices):
4242 assert len(values.shape) == 3
4243 assert len(indices.shape) == 2
4244 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004245
Kevin Cheng77d0f762020-11-24 10:26:32 -08004246 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4247
Kevin Cheng550ccc52021-03-03 11:21:43 -08004248 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004249
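    # Worked example (illustrative): values of shape [N=2, K=5, C=3] gathered
    # with indices of shape [2, W=4] produce an output of shape [2, 4, 3];
    # scatterOp below instead keeps the values_in shape [N, K, C] unchanged.
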
    @staticmethod
    def scatterOp(ser, values_in, indices, input):
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        return ser.addOutput(output_shape, values_in.dtype)

    @staticmethod
    def tableOp(ser, input, table_dtype):
        # Same shape as the input, but dtype dependent on table dtype
        assert table_dtype == DType.INT16 or table_dtype == DType.INT8
        output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name=None,
    ):
        if error_name == ErrorIf.WrongRank:
            output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
        else:
            if error_name == ErrorIf.BatchMismatch:
                output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
            elif error_name == ErrorIf.ChannelMismatch:
                output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
            else:
                output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        return serializer.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(output_shape, out_dtype)
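
    # Minimal usage sketch (illustrative only; the real call sites are the
    # build_* functions referenced in TOSA_OP_LIST above). A builder typically
    # forwards its serializer and the input placeholders it created, e.g.:
    #
    #   result_tens = OutputShaper.conv2dOp(
    #       self.ser, ifm, filter, strides, padding, dilations
    #   )
    #
    # where `self.ser`, `ifm` and `filter` are assumed names for the test
    # generator's serializer and the tensors created earlier in the build step.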