blob: 2478331cbaa69f97766b0f2cc9e1a94fd082990e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
# so that tosa_serializer / tosa can be imported below without installing them.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
# Convenience variables to the flatc-generated types that should be enums, but aren't.
# The generated classes only hold int constants, so instantiate each once and use
# the instance as an enum-like namespace (e.g. DType.INT8, Op.ADD).
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
51
class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion"""

    def __init__(self):
        # Stateless: every generator is a static method.
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):
        """Pick a zero point for *dtype*.

        INT8/UINT8 get a random in-range zero point. For any other dtype the
        zero point must normally be 0, so when a zero-point ERROR_IF test is
        requested a non-zero value is forced to trigger the error condition.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        if error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
            # Force a non-zero zero point for the ERROR_IF case.
            zp = testGen.randInt(-128, 128)
            return zp if zp != 0 else 1
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Build UnaryQuantInfo, injecting a bad zero point on the side
        selected by error_name (input or output), if any."""
        qinfo = ts.TosaSerializerQuantInfo()
        # Draw input zp before output zp so the RNG sequence matches callers'
        # expectations regardless of which error (if any) is requested.
        in_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        out_err = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        input_zp = TosaQuantGen.getQinfo(testGen, dtype, in_err)
        output_zp = TosaQuantGen.getQinfo(testGen, dtype, out_err)
        qinfo.UnaryQuantInfo(input_zp, output_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Build ConvQuantInfo from either one dtype or an
        [input, weights, accumulator] dtype list."""
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypes = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypes = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypes[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypes[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Build MatMulQuantInfo with random zero points for both operands."""
        qinfo = ts.TosaSerializerQuantInfo()
        a_zp = TosaQuantGen.getQinfo(testGen, dtype)
        b_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.MatMulQuantInfo(a_zp, b_zp)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        """Build PadQuantInfo with a random input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a fixed-point (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32. scale32 selects a
        31-bit (scale32=True) or 15-bit fractional multiplier. The returned
        shift is kept in the legal range [2, 62] by trading multiplier bits.
        """
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # frexp keeps the sign in the mantissa; fold it out.
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding pushed the mantissa up to 1.0: halve it and bump the exponent.
            multiplier = multiplier // 2
            exponent = exponent + 1

        shift = scaleBits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
158
159
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        # Stateless: every generator is a static method.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Give every operand the same random shape of the requested rank.

        Under ErrorIf.RankMismatch the working shape is regenerated with a
        different rank between appends (skipped when the loop index is 1),
        deliberately producing operands with mismatched ranks.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # rank 1 can only grow (a rank-0 mismatch is not generated here)
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Same random NHWC shape for every operand; rank must be 4 unless a
        WrongRank error test is being built. Batch is capped by --max-batch-size."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Build shapes for scatter/gather-style ops: a values_in shape plus a
        second input that shares batch and channel dims but has a random W."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Same shape for all operands except one randomly-chosen input, which
        gets a random dimension set to 1 to exercise broadcasting."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            if error_name == ErrorIf.RankMismatch:
                bcast_idx = -1  # Turn off broadcast because we are not testing it
                if rank == 1 and i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Shapes for CONV2D: NHWC input, OHWI filter (kernel size from the op
        definition), and a bias of length ofm_depth."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Shapes for CONV3D: NDHWC input, ODHWI filter, and a bias of length
        ofm_channel."""
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Shapes for TRANSPOSE_CONV2D: NHWC input, OHWI filter, bias of
        length ofm_depth (same layout as tgConv2D)."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Shapes for DEPTHWISE_CONV2D: NHWC input, HWCM filter with a bounded
        channel multiplier M, and a bias of length C * M."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Shapes for FULLY_CONNECTED: [N, IC] input, [OC, IC] weights and a
        bias of length OC, with OC drawn from the configured shape range."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Shapes for MATMUL: a of [N, H, C], b of [N, C, OC], with OC capped
        to 1 when a is already large to bound the output size."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Shapes for CONCAT: the normal operands plus up to 3 extra tensors,
        all sharing one random shape (the axis split happens later in
        tgConcatConstInput)."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Re-shape concat inputs by repeatedly halving shapeList[0] along
        *axis*, so multiple const inputs don't all have the full size.

        Returns shapeList unchanged when there are only two shapes or the
        axis is too short to split across all inputs.
        """
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                # Last iteration: the remainder becomes the final input
                new_shapeList.append(shape.copy())

        return new_shapeList
457
458
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        # Stateless: every generator is a static method.
        pass
467
468 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100469 def agNone(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800470 """A trivial argument generator for operators that don't take any
471 non-tensor arguments"""
472 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700473
474 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100475 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800476 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700477 axes = []
478
479 shape = shapeList[0]
480
481 for a in range(0, len(shape)):
Matthew Haddon43e37192021-07-09 14:13:02 +0100482 axes.append(("axis{}".format(a), [a]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700483 return axes
484
    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, dilation) combinations for convolutions.

        Combinations are enumerated exhaustively up to the configured maxima
        (plus a few oversized values for small inputs), then sampled sparsely
        via a counter so the total test count stays manageable. Combinations
        whose padded input would not exceed the kernel or dilation are skipped.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        # (sets de-duplicate; iteration order is fixed later by sorted())
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        # n counts every combination (not just emitted ones) so that every
        # sparsity-th combination is selected.
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list
554
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, dilation, output_shape) combinations for
        TRANSPOSE_CONV2D, sparsely sampled like agConv. The output shape is
        computed per combination and appended to both the name and the args.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        # n counts every combination so every sparsity-th one is emitted.
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        # NOTE(review): local name 'os' shadows the imported os
                        # module inside this loop body.
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list
623
624 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100625 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700626 arg_list = []
627 rank = len(shapeList[0])
628
Les Bell7ffccce2021-07-28 15:37:02 +0100629 # Exhaustively test combinations of padding on each side of each dimension
630 # - the range of padding values is defined by pad_min and pad_max
631 # - for padding >9, the name format needs to be more distinctive
632 pad_min, pad_max = 0, 1
633 pad_values = [x for x in range(pad_min, pad_max + 1)]
634 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
635 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700636
Les Bell7ffccce2021-07-28 15:37:02 +0100637 for paddings in shape_pad_values:
638 name = "pad"
639 for r in range(rank):
640 before, after = paddings[r]
641 name = f"{name}{before}{after}"
642 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700643
644 return arg_list
645
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (stride, padding, kernel) combinations for pooling ops,
        sparsely sampled; combinations where the padding reaches the kernel
        size or the padded input does not exceed the kernel are skipped.
        """
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        # n counts every combination so every sparsity-th one is emitted.
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
696
697 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100698 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700699 arg_list = []
700
701 # Enumerate the output types here
702 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800703 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700704 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800705 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700706 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800707 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700708 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800709 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700710 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800711 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700712 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800713 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700714
715 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800716 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700717
718 return arg_list
719
720 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100721 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700722 arg_list = []
723
724 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100725 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
726 if inDtype == DType.UINT8 and dtype != DType.INT8:
727 # The only output dtype for UINT8 is INT8, skip all other combinations
728 continue
729 if inDtype != DType.INT8 and dtype == DType.UINT8:
730 # The only input dtype for UINT8 is INT8, skip all other combinations
731 continue
732
Kevin Cheng550ccc52021-03-03 11:21:43 -0800733 for scale32 in [False, True]:
734 for double_round in [False, True]:
735 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700736
737 if inDtype == DType.INT48 and scale32:
738 # Illegal condition. Must be scale32=False
739 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100740 if double_round and not scale32:
741 # Illegal condition. ERROR_IF(!scale32 && double_round)
742 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700743
Kevin Cheng550ccc52021-03-03 11:21:43 -0800744 arg_list.append(
745 (
746 "out{}_sc{}_dr{}_pc{}".format(
747 DTypeNames[dtype],
748 int(scale32),
749 int(double_round),
750 int(per_channel),
751 ),
752 [dtype, scale32, double_round, per_channel],
753 )
754 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700755
756 return arg_list
757
Kevin Chengaee1fac2020-11-11 13:54:06 -0800758 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100759 def agMul(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800760 arg_list = []
761
762 if dtype is DType.INT32:
763 for p in range(testGen.args.num_rand_permutations):
764
765 shift = testGen.randInt(0, 32)
766
Kevin Cheng550ccc52021-03-03 11:21:43 -0800767 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800768 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100769 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800770
771 return arg_list
772
773 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100774 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
Kevin Chengaee1fac2020-11-11 13:54:06 -0800775 arg_list = []
776
Kevin Cheng550ccc52021-03-03 11:21:43 -0800777 arg_list.append(("roundTrue", [True]))
778 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800779
780 return arg_list
781
Eric Kunzee5e26762020-10-13 16:11:07 -0700782 # Helper function for reshape. Gets some factors of a larger number.
783 @staticmethod
784 def getFactors(val, start=1):
785 factors = []
786
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100787 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700788 if (val % i) == 0:
789 factors.append(i)
790
791 return factors
792
793 @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate RESHAPE argument sets: random target shapes preserving the
        input's element count.

        Each argument set is ("perm<p>_rank<r>", [newShape]); occasionally one
        dimension is replaced with -1 to exercise the inferred-dimension path.
        Duplicate shapes are rejected, with an escape counter so the retry
        loop cannot run forever.
        """
        arg_list = []

        origShape = shapeList[0]

        # Every generated shape must multiply out to this element count
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough factors to fill a shape of this rank
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # The last dimension absorbs whatever element count remains
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    # Give up on this permutation; found may still be True here
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
848
Eric Kunzee5e26762020-10-13 16:11:07 -0700849 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100850 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700851 arg_list = []
852
853 ifm_shape = shapeList[0]
854
Jeremy Johnsona6185572021-06-21 15:55:35 +0100855 # Get all permutations
856 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700857
Jeremy Johnsona6185572021-06-21 15:55:35 +0100858 # Limit to possible permutations from shape dimension or argument setting
859 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700860
Jeremy Johnsona6185572021-06-21 15:55:35 +0100861 # Get random permutation generator that uses all permutations
862 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700863
Jeremy Johnsona6185572021-06-21 15:55:35 +0100864 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700865 arg_list = [
866 ("perm{}".format(p), [random_permutations[p].tolist()])
867 for p in range(limit)
868 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700869 return arg_list
870
871 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100872 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700873 arg_list = []
874
875 ifm_shape = shapeList[0]
876 rank = len(ifm_shape)
877
878 for p in range(testGen.args.num_rand_permutations):
879 begin = []
880 size = []
881
Kevin Cheng550ccc52021-03-03 11:21:43 -0800882 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700883
884 for i in range(rank):
885 if ifm_shape[i] > 1:
886 begin.append(testGen.randInt(0, ifm_shape[i]))
887 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
888
889 # Invalid slice size?
890 if size[i] == 0:
891 valid = False
892 else:
893 begin.append(0)
894 size.append(1)
895
896 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800897 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700898 return arg_list
899
900 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +0100901 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700902 arg_list = []
903
904 ifm_shape = shapeList[0]
905 rank = len(ifm_shape)
906
907 for p in range(testGen.args.num_rand_permutations):
908
909 # Pick a few random, but small multiple values
910 # because otherwise this has a tendency to generate
911 # enormous tensors
912 multiples = []
913 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +0100914 if ifm_shape[i] > 1000:
915 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
916 multiples.append(1)
917 elif max(ifm_shape) > 1000:
918 multiples.append(2)
919 else:
920 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -0800921 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700922
923 return arg_list
924
925 @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate RESIZE argument sets for every legal {mode, output dtype}
        combination of the given input dtype.

        Each argument set is (name, [mode, stride, offset, shift, stride_fp,
        offset_fp, output_dims, dtype, outputDType]).  FLOAT outputs use the
        floating-point stride/offset with shift 0; integer outputs use
        fixed-point values scaled by 2**shift.  When error_name is set, the
        computed values are perturbed via TosaErrorIfArgGen.eiResizeErrorIf
        so that the requested ERROR_IF condition fires.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Grow the output until the downscale factor is below 16 per axis
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    # Pixel-space centres of the input and output planes
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    # Stride/offset chosen so input and output centres align
                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT output: fp stride/offset are used directly, shift is 0
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )
                    else:
                        # Integer output: quantize stride/offset to fixed point,
                        # starting at the maximum shift of 11
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        # Reduce the shift until all values fit the +/-(16 << shift) bounds
                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                mode,
                                dtype,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )
                        else:
                            outputDTypeNew = outputDType

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if mode == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDTypeNew),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    mode,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDTypeNew,
                                ],
                            )
                        )

        return arg_list
1095
Matthew Haddon1c00b712021-10-01 15:51:03 +01001096 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001097 # CondIf generates the condition values here.
1098 # Convert to tensors in the build function, along with the
1099 # then and else blocks
1100 arg_list = []
1101
1102 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001103 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001104
1105 return arg_list
1106
Matthew Haddon1c00b712021-10-01 15:51:03 +01001107 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001108 # While loop: 0 iterations, 1, more than 1
1109 arg_list = []
1110
1111 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001112 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001113
1114 return arg_list
1115
class TosaErrorIfArgGen:
    """Helpers that mutate otherwise-valid operator arguments so that a
    specific ERROR_IF condition triggers in the reference model."""

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
        """Perturb RESIZE arguments to provoke the requested ERROR_IF.

        Takes the legal shift/stride/offset values computed by agResize and
        returns (shift, stride, stride_fp, offset, offset_fp, outputDType)
        with only the field(s) relevant to error_name made invalid; all
        other values pass through unchanged.
        """

        if outputDType == DType.FLOAT:
            # Float resize uses the *_fp arguments and a zero shift
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                # Randomly corrupt either the vertical or horizontal stride
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            # Integer resize uses the fixed-point stride/offset and shift
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]


        if error_name == ErrorIf.WrongOutputType:
            # Pick any dtype that is NOT legal for this {mode, input dtype}
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
            elif dtype == DType.FLOAT:
                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
            # NOTE(review): incorrect_types is unbound if dtype matches none of
            # the branches above (e.g. combined with a WrongInputType dtype) —
            # confirm callers never request WrongOutputType for such inputs
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        # NOTE(review): error_name is compared against plain strings here,
        # while the rest of the ERROR_IF plumbing uses ErrorIf enum members —
        # confirm callers pass the string form for these two checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append('eiDummyInput')
            else:
                # Drop one input instead of adding a bogus one
                input_list = input_list[:-1]
        if error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append('eiDummyOutput')
            else:
                # Empty the output list entirely
                output_list = []
        return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001192
1193class TosaErrorValidator:
1194
Matthew Haddon848efb42021-09-09 12:30:53 +01001195 @staticmethod
1196 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1197 # Check ERROR_IF statements
1198
1199 for val_fcn in validator_fcns:
1200 val_result = val_fcn(True, **kwargs)
1201
1202 validator_name = val_result['error_name']
1203 error_result = val_result['error_result']
1204 error_reason = val_result['error_reason']
1205
1206 if error_result:
1207 if error_name == validator_name:
1208 serializer.setExpectedReturnCode(2, error_reason)
1209 else:
1210 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1211 return None # Return None to delete test if wrong ERROR_IF is hit
1212 else:
1213 if error_name == validator_name:
1214 print(f"No ERROR_IF hit for {error_name}")
1215 return None
1216
1217 @staticmethod
1218 def evWrongInputType(check=False, **kwargs):
1219 all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1220
1221 # Find the unsupported input data types
1222 assert 'op' in kwargs
1223 op = kwargs['op']
1224 input_dtypes = op['types']
1225 wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))
1226
1227 error_name = ErrorIf.WrongInputType
1228 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1229 error_result = False
1230 error_reason = "Input data type not supported for this operator"
1231
1232 if check:
1233 input_dtype = kwargs['input_dtype']
1234 if input_dtype not in input_dtypes:
1235 error_result = True
1236
1237 info_dict = {
1238 "error_name": error_name,
1239 "error_result": error_result,
1240 "error_reason": error_reason,
1241 "param_reqs": param_reqs
1242 }
1243 return info_dict
1244
1245 @staticmethod
1246 def evWrongOutputType(check=False, **kwargs):
1247 error_name = ErrorIf.WrongOutputType
1248 param_reqs = {"rank": None, "dtype": None, "shape": None}
1249 error_result = False
1250 error_reason = "Output data type not supported for this configuration of operator"
1251
1252 if check:
1253 input_dtype = kwargs['input_dtype']
1254 output_dtype = kwargs['output_dtype']
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001255 op = kwargs['op']
Matthew Haddon848efb42021-09-09 12:30:53 +01001256
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001257 if op['op'] == Op.RESIZE:
1258 mode = kwargs['mode']
1259 if (
1260 (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
1261 (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
1262 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
1263 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
1264 (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
1265 ):
1266 error_result = True
1267 else:
1268 if output_dtype != input_dtype:
1269 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001270
1271 info_dict = {
1272 "error_name": error_name,
1273 "error_result": error_result,
1274 "error_reason": error_reason,
1275 "param_reqs": param_reqs
1276 }
1277 return info_dict
1278
1279 @staticmethod
1280 def evWrongRank(check=False, **kwargs):
1281 all_ranks = (1, 2, 3, 4, 5)
1282
1283 # Make a list of incorrect ranks
1284 assert 'op' in kwargs
1285 op = kwargs['op']
1286 rmin, rmax = op['rank']
1287 rank_range = range(rmin, rmax + 1)
1288 incorrect_ranks = list(set(all_ranks) - set(rank_range))
1289 # Set minimum incorrect rank to 3 to avoid index error
1290 if op['op'] == Op.RESIZE:
1291 incorrect_ranks = [3, 5]
1292
1293 error_name = ErrorIf.WrongRank
1294 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1295 error_result = False
1296 error_reason = "Rank not supported for this operator"
1297
1298 if check:
1299 input_shape = kwargs['input_shape']
1300 if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
1301 error_result = True
1302
1303 info_dict = {
1304 "error_name": error_name,
1305 "error_result": error_result,
1306 "error_reason": error_reason,
1307 "param_reqs": param_reqs
1308 }
1309 return info_dict
1310
1311 @staticmethod
1312 def evWrongInputList(check=False, **kwargs):
1313 error_name = ErrorIf.WrongInputList
1314 param_reqs = {"rank": None, "dtype": None, "shape": None}
1315 error_result = False
1316 error_reason = "Op input list does not match expected input"
1317
1318 if check:
1319 op = kwargs['op']
1320 input_list = kwargs['input_list']
1321 num_operands = kwargs['num_operands']
1322 if len(input_list) != num_operands:
1323 error_result = True
1324
1325 info_dict = {
1326 "error_name": error_name,
1327 "error_result": error_result,
1328 "error_reason": error_reason,
1329 "param_reqs": param_reqs
1330 }
1331 return info_dict
1332
1333 @staticmethod
1334 def evWrongOutputList(check=False, **kwargs):
1335 error_name = ErrorIf.WrongOutputList
1336 param_reqs = {"rank": None, "dtype": None, "shape": None}
1337 error_result = False
1338 error_reason = "Op output list does not match expected output"
1339
1340 if check:
1341 output_list = kwargs['output_list']
1342 # Note this will be incorrect if an operator returns more than one output
1343 if len(output_list) != 1:
1344 error_result = True
1345
1346 info_dict = {
1347 "error_name": error_name,
1348 "error_result": error_result,
1349 "error_reason": error_reason,
1350 "param_reqs": param_reqs
1351 }
1352 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001353
1354 @staticmethod
1355 def evMaxDimExceeded(check=False, **kwargs):
1356 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001357 param_reqs = {
1358 "rank": [4,4],
1359 "dtype": [DType.INT8],
1360 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1361 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001362 error_result = False
1363 error_reason = "At least one maximum dimension is larger than 16384"
1364
1365 if check:
1366 input_shape = kwargs['input_shape'].shape
1367 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1368 if ((input_shape[1] > 16384) or
1369 (input_shape[2] > 16384) or
1370 (output_shape[0] > 16384) or
1371 (output_shape[1] > 16384)):
1372 error_result = True
1373
1374 info_dict = {
1375 "error_name": error_name,
1376 "error_result": error_result,
1377 "error_reason": error_reason,
1378 "param_reqs": param_reqs
1379 }
1380 return info_dict
1381
1382 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001383 def evBatchMismatch(check=False, **kwargs):
1384 error_name = ErrorIf.BatchMismatch
1385 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1386 error_result = False
1387 error_reason = "Input batch size not equal to output batch size"
1388
1389 assert 'op' in kwargs
1390 op = kwargs['op']
1391 rmin, rmax = op['rank']
1392 rank_range = range(rmin, rmax + 1)
1393
1394 if check:
1395 input_shape = kwargs['input_shape'].shape
1396 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1397
1398 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
1399 error_result = True
1400
1401 info_dict = {
1402 "error_name": error_name,
1403 "error_result": error_result,
1404 "error_reason": error_reason,
1405 "param_reqs": param_reqs
1406 }
1407 return info_dict
1408
1409 @staticmethod
1410 def evChannelMismatch(check=False, **kwargs):
1411 error_name = ErrorIf.ChannelMismatch
1412 param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
1413 error_result = False
1414 error_reason = "Input channel size not equal to output channel size"
1415
1416 assert 'op' in kwargs
1417 op = kwargs['op']
1418 rmin, rmax = op['rank']
1419 rank_range = range(rmin, rmax + 1)
1420
1421 if check:
1422 input_shape = kwargs['input_shape'].shape
1423 output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
1424 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
1425 error_result = True
1426
1427 info_dict = {
1428 "error_name": error_name,
1429 "error_result": error_result,
1430 "error_reason": error_reason,
1431 "param_reqs": param_reqs
1432 }
1433 return info_dict
1434
1435 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001436 def evStrideSmallerEqualZero(check=False, **kwargs):
1437 error_name = ErrorIf.StrideSmallerEqualZero
1438 param_reqs = {"rank": None, "dtype": None, "shape": None}
1439 error_result = False
1440 error_reason = "Stride value smaller than or equal zero"
1441
1442 if check:
1443 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001444 output_dtype = kwargs['output_dtype']
1445 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1446 stride = kwargs['stride'] # Work around wrong input/output type tests
1447 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001448 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001449 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1450 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001451 else:
1452 stride = kwargs['stride']
1453
1454 if min(stride) <= 0:
1455 error_result = True
1456
1457 info_dict = {
1458 "error_name": error_name,
1459 "error_result": error_result,
1460 "error_reason": error_reason,
1461 "param_reqs": param_reqs
1462 }
1463 return info_dict
1464
1465 @staticmethod
1466 def evStrideLargerEqualMax(check=False, **kwargs):
1467 error_name = ErrorIf.StrideLargerEqualMax
1468 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1469 error_result = False
1470 error_reason = "Stride value larger than or equal to maximum value"
1471
1472 if check:
1473 shift = kwargs['shift']
1474 input_dtype = kwargs['input_dtype']
1475 stride = kwargs['stride']
1476 if input_dtype in [DType.INT8, DType.INT16]:
1477 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1478 error_result = True
1479 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1480 error_result = True
1481
1482 info_dict = {
1483 "error_name": error_name,
1484 "error_result": error_result,
1485 "error_reason": error_reason,
1486 "param_reqs": param_reqs
1487 }
1488 return info_dict
1489
1490
1491 @staticmethod
1492 def evStrideLargerDimension(check=False, **kwargs):
1493 error_name = ErrorIf.StrideLargerDimension
1494 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1495 error_result = False
1496 error_reason = "Stride value larger than or equal to H/W dimension"
1497
1498 if check:
1499 shape = kwargs['input_shape'].shape
1500 input_dtype = kwargs['input_dtype']
1501 stride = kwargs['stride_fp']
1502
1503 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1504 error_result = True
1505
1506 info_dict = {
1507 "error_name": error_name,
1508 "error_result": error_result,
1509 "error_reason": error_reason,
1510 "param_reqs": param_reqs
1511 }
1512 return info_dict
1513
1514
1515 @staticmethod
1516 def evOffsetSmallerEqualMin(check=False, **kwargs):
1517 error_name = ErrorIf.OffsetSmallerEqualMin
1518 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1519 error_result = False
1520 error_reason = "Offset value smaller than or equal to minimum value"
1521
1522 if check:
1523 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001524 output_dtype = kwargs['output_dtype']
1525 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001526 offset = kwargs['offset_fp']
1527 else:
1528 offset = kwargs['offset']
1529
1530 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1531 error_result = True
1532 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1533 error_result = True
1534
1535 info_dict = {
1536 "error_name": error_name,
1537 "error_result": error_result,
1538 "error_reason": error_reason,
1539 "param_reqs": param_reqs
1540 }
1541 return info_dict
1542
1543 @staticmethod
1544 def evOffsetLargerEqualMax(check=False, **kwargs):
1545 error_name = ErrorIf.OffsetLargerEqualMax
1546 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1547 error_result = False
1548 error_reason = "Offset value larger than or equal to maximum value"
1549
1550 if check:
1551 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001552 output_dtype = kwargs['output_dtype']
1553 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001554 offset = kwargs['offset_fp']
1555 else:
1556 offset = kwargs['offset']
1557
1558 if shift >= 0:
1559 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1560 error_result = True
1561
1562 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1563 error_result = True
1564 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1565 error_result = True
1566
1567 info_dict = {
1568 "error_name": error_name,
1569 "error_result": error_result,
1570 "error_reason": error_reason,
1571 "param_reqs": param_reqs
1572 }
1573 return info_dict
1574
1575 @staticmethod
1576 def evShiftNotZero(check=False, **kwargs):
1577 error_name = ErrorIf.ShiftNotZero
1578 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1579 error_result = False
1580 error_reason = "Shift value must be zero for float input"
1581
1582 if check:
1583 shift = kwargs['shift']
1584 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001585 output_dtype = kwargs['output_dtype']
1586 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001587 error_result = True
1588
1589 info_dict = {
1590 "error_name": error_name,
1591 "error_result": error_result,
1592 "error_reason": error_reason,
1593 "param_reqs": param_reqs
1594 }
1595 return info_dict
1596
1597
1598 @staticmethod
1599 def evShiftSmallerOne(check=False, **kwargs):
1600 error_name = ErrorIf.ShiftSmallerOne
1601 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1602 error_result = False
1603 error_reason = "Shift value smaller than one"
1604
1605 if check:
1606 shift = kwargs['shift']
1607 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001608 output_dtype = kwargs['output_dtype']
1609 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001610 error_result = True
1611
1612 info_dict = {
1613 "error_name": error_name,
1614 "error_result": error_result,
1615 "error_reason": error_reason,
1616 "param_reqs": param_reqs
1617 }
1618 return info_dict
1619
1620 @staticmethod
1621 def evShiftLargerEleven(check=False, **kwargs):
1622 error_name = ErrorIf.ShiftLargerEleven
1623 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1624 error_result = False
1625 error_reason = "Shift value larger than eleven"
1626
1627 if check:
1628 shift = kwargs['shift']
1629 if shift > 11:
1630 error_result = True
1631
1632 info_dict = {
1633 "error_name": error_name,
1634 "error_result": error_result,
1635 "error_reason": error_reason,
1636 "param_reqs": param_reqs
1637 }
1638 return info_dict
1639
1640
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001641 @staticmethod
1642 def evRankMismatch(check=False, **kwargs):
1643 error_name = ErrorIf.RankMismatch
1644 param_reqs = {"rank": None, "dtype": None, "shape": None}
1645 error_result = False
1646 error_reason = "Input Rank does not match output rank"
1647
1648 if check:
1649 input1_shape = kwargs['input1'].shape
1650 input2_shape = kwargs['input2'].shape
1651 output_shape = kwargs['result_tensor'].shape
1652 if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
1653 error_result = True
1654
1655 info_dict = {
1656 "error_name": error_name,
1657 "error_result": error_result,
1658 "error_reason": error_reason,
1659 "param_reqs": param_reqs
1660 }
1661 return info_dict
1662
Matthew Haddone4ecdb22021-09-28 11:38:21 +01001663 @staticmethod
1664 def evInputZeroPointNotZero(check=False, **kwargs):
1665 error_name = ErrorIf.InputZeroPointNotZero
1666 param_reqs = {
1667 "rank": None,
1668 "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
1669 "shape": None
1670 }
1671 error_result = False
1672 error_reason = "Input DType not INT8 and zero point not 0"
1673
1674 if check:
1675 input_dtype = kwargs['input_dtype']
1676 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1677 qinfo = kwargs['qinfo'].ints
1678 input_zero_point = qinfo[0][1]
1679 if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
1680 error_result = True
1681
1682 info_dict = {
1683 "error_name": error_name,
1684 "error_result": error_result,
1685 "error_reason": error_reason,
1686 "param_reqs": param_reqs
1687 }
1688 return info_dict
1689
1690
1691 @staticmethod
1692 def evOutputZeroPointNotZero(check=False, **kwargs):
1693 error_name = ErrorIf.OutputZeroPointNotZero
1694 param_reqs = {
1695 "rank": None,
1696 "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
1697 "shape": None
1698 }
1699 error_result = False
1700 error_reason = "Output DType not INT8 and zero point not 0"
1701
1702 if check:
1703 output_dtype = kwargs['output_dtype']
1704 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1705 qinfo = kwargs['qinfo'].ints
1706 output_zero_point = qinfo[1][1]
1707 if output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
1708 error_result = True
1709
1710 info_dict = {
1711 "error_name": error_name,
1712 "error_result": error_result,
1713 "error_reason": error_reason,
1714 "param_reqs": param_reqs
1715 }
1716 return info_dict
1717
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001718
class TosaInvalidValidator:
    """Predicates that detect argument combinations producing an invalid
    (untestable) operator instance, so the generator can skip or reject them.
    Each iv* function returns True when the combination is invalid."""

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the resize mode / input dtype / output dtype combo is invalid."""
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Fix: the original expression
            #   not A or not B or not C or D
            # was always True, because the valid dtype pairs A, B, C are
            # mutually exclusive (at least two of the 'not' terms always hold).
            # Invalid means: none of the valid (input, output) pairs matches.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            # NOTE(review): DType.INT32 in this allowed list looks like it may
            # be intended to be DType.INT16 — confirm against the TOSA spec.
            return (
                (input_dtype != output_dtype) or
                (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if the resize stride (float or integer form) is non-positive."""
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        """Return True if the computed output H or W of a conv/pool op is <= 0.

        Uses the same output-size arithmetic as the conv2d/depthwise/pool
        shape calculations; padding is read as [H_before, H_after, W_before,
        W_after] and strides/dilations as [H, W] pairs.
        """
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        # Renamed from 'input'/'filter' to avoid shadowing Python builtins
        ifm_shape = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter_shape = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            # Pooling args pack the kernel where convs pack dilations
            kernel = args[2]

        if opName.startswith('conv2d'):
            h = (
                ifm_shape[1]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[2]
                - (filter_shape[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            # Depthwise filters store kernel H/W in dims 0/1 (not 1/2)
            h = (
                ifm_shape[1]
                - filter_shape[0]
                - (filter_shape[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                ifm_shape[2]
                - filter_shape[1]
                - (filter_shape[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if the requested output H or W (args[3]) is non-positive."""
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
1835
1836
Kevin Cheng550ccc52021-03-03 11:21:43 -08001837
class TosaTestGen:
    """Generates TOSA operator tests: builds random tensors and operator
    graphs, and serializes each test (.tosa plus desc.json) under the
    configured output directory."""

    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6
1841
    def __init__(self, args):
        """Create a generator configured from parsed command-line args.

        Reads args.output_dir (test output root) and args.random_seed
        (seeds the numpy Generator used for all random content).
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None  # TosaSerializer, created per-test by createSerializer()
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): dynamic op lists are created before the defaults are
        # initialized — ordering presumably matters; confirm before reordering.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
1853
1854 def createSerializer(self, opName, testPath):
1855 self.testPath = os.path.join(opName, testPath)
1856
1857 fullPath = os.path.join(self.basePath, self.testPath)
1858 os.makedirs(fullPath, exist_ok=True)
1859 self.ser = ts.TosaSerializer(fullPath)
1860
1861 def getSerializer(self):
1862 return self.ser
1863
1864 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001865 with open(
1866 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
1867 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07001868 fd.write(self.ser.serialize())
1869
Kevin Cheng550ccc52021-03-03 11:21:43 -08001870 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
1871 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07001872
Matthew Haddon74567092021-07-16 15:38:20 +01001873 def resetRNG(self, seed=None):
1874 if seed == None:
1875 seed = self.random_seed + 1
1876 self.rng = np.random.default_rng(seed)
1877
Eric Kunzee5e26762020-10-13 16:11:07 -07001878 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07001879 if dtype == DType.BOOL:
1880 np_dt = np.bool
1881 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07001882 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001883 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001884 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001885 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001886 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
1887 elif dtype == DType.UINT8:
1888 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001889 elif dtype == DType.INT16:
1890 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
1891 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001892 return np.int32(
1893 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
1894 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001895 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001896 return np.int64(
1897 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
1898 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001899 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001900 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001901 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001902 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001903
Kevin Cheng989cb052021-04-28 16:29:44 -07001904 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001905 placeholders = []
1906
Kevin Cheng989cb052021-04-28 16:29:44 -07001907 assert len(shape_list) == len(dtype_list)
1908
1909 for idx, shape in enumerate(shape_list):
1910 arr = self.getRandTensor(shape, dtype_list[idx])
1911 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001912
1913 return placeholders
1914
Kevin Cheng989cb052021-04-28 16:29:44 -07001915 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001916 consts = []
1917
Kevin Cheng989cb052021-04-28 16:29:44 -07001918 assert len(shape_list) == len(dtype_list)
1919
1920 for idx, shape in enumerate(shape_list):
1921 arr = self.getRandTensor(shape, dtype_list[idx])
1922 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001923
1924 return consts
1925
1926 def makeShape(self, rank):
1927 if self.targetted_shape:
1928 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001929 return np.int32(
1930 self.rng.integers(
1931 low=self.args.tensor_shape_range[0],
1932 high=self.args.tensor_shape_range[1],
1933 size=rank,
1934 )
1935 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001936
1937 def setTargetShape(self, shape):
1938 self.targetted_shape = shape
1939
1940 def randInt(self, low=0, high=256):
1941 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
1942
    def getRandNumberDType(self, dtype):
        """Return a single random scalar of the given TOSA dtype.

        FLOAT -> Python float in [0, 1); BOOL -> Python bool;
        integer dtypes -> np.int32 scalar (np.int64 for INT48) over the
        dtype's full range. Raises Exception for unknown dtypes.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size: INT48 exceeds int32 range, so return int64 here
            # instead of falling through to the int32 return below
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]
1965
1966 def shapeStr(self, shape):
1967
1968 sStr = []
1969 # Convert to strings
1970 for i in shape:
1971 sStr.append(str(i))
1972
Kevin Cheng550ccc52021-03-03 11:21:43 -08001973 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001974
1975 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001976 if isinstance(t, list):
1977 assert len(t) >= 2
1978 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001979 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001980 if t == DType.BOOL:
1981 return "b"
1982 elif t == DType.INT4:
1983 return "i4"
1984 elif t == DType.INT8:
1985 return "i8"
1986 elif t == DType.UINT8:
1987 return "u8"
1988 elif t == DType.INT16:
1989 return "i16"
1990 elif t == DType.INT32:
1991 return "i32"
1992 elif t == DType.INT48:
1993 return "i48"
1994 elif t == DType.FLOAT:
1995 return "float"
1996 else:
1997 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001998
1999 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002000 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08002001 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07002002 return 4
2003 elif t == DType.INT8:
2004 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08002005 elif t == DType.UINT8:
2006 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07002007 elif t == DType.INT16:
2008 return 16
2009 elif t == DType.INT32:
2010 return 32
2011 elif t == DType.INT48:
2012 return 48
2013 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002014 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07002015
2016 # Argument generators
2017 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
2018 # Where the string descriptor is used to generate the test name and
2019 # The build_fcn_arg_list is expanded and passed to the operator test
2020 # build function
2021
    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        """Build a unary operator test.

        op: a bare op id (int) or an op dict carrying 'op' and 'operands'.
        a: input tensor; the output shape/dtype comes from OutputShaper.
        validator_fcns / error_name: optional ERROR_IF negative-test hooks.
        qinfo: optional quantization info; may be replaced below for the
        WrongOutputType ERROR_IF case.
        Returns the result tensor.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # build_placeholder returns an int, ABS/other ops does not
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
            return result_tens
        elif op['op'] == Op.IDENTITY:
            # IDENTITY also skips the ERROR_IF processing below
            self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
                )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo = qinfo,
            result_tensor = result_tens,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )

        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
        return result_tens
2064
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002065 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
2066 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
2067
2068
2069 # Invalidate Input/Output list for error if checks.
2070 input_list = [a.name, b.name]
2071 output_list = [result_tens.name]
2072 pCount, cCount = op["operands"]
2073 num_operands = pCount + cCount
2074 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
2075
2076 TosaErrorValidator.evValidateErrorIfs(
2077 self.ser,
2078 validator_fcns,
2079 error_name,
2080 op=op,
2081 input1 = a,
2082 input2 = b,
2083 input_dtype = a.dtype,
2084 output_dtype = result_tens.dtype,
2085 result_tensor = result_tens,
2086 input_list=input_list,
2087 output_list=output_list,
2088 num_operands=num_operands,
2089 )
2090
2091 self.ser.addOperator(op['op'], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07002092 return result_tens
2093
2094 def build_binary_nonbroadcast(self, op, a, b):
2095 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002096 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002097 return result_tens
2098
Kevin Chengaee1fac2020-11-11 13:54:06 -08002099 def build_arithmetic_right_shift(self, op, a, b, round):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002100 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002101
2102 attr = ts.TosaSerializerAttribute()
2103 attr.ArithmeticRightShiftAttribute(round)
2104
Matthew Haddon848efb42021-09-09 12:30:53 +01002105 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08002106 return result_tens
2107
2108 def build_mul(self, op, a, b, shift):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002109 result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
Eric Kunzee5e26762020-10-13 16:11:07 -07002110
2111 # Special for multiply:
2112 # Force the result to INT32 for INT types
2113 if a.dtype != DType.FLOAT:
2114 result_tens.setDtype(DType.INT32)
2115
Kevin Chengaee1fac2020-11-11 13:54:06 -08002116 attr = ts.TosaSerializerAttribute()
2117 attr.MulAttribute(shift)
2118
Matthew Haddon848efb42021-09-09 12:30:53 +01002119 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002120 return result_tens
2121
2122 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002123 # Constant size depending on type, random values
2124 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07002125 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002126 table_arr = self.getRandTensor([513], table_dtype)
2127 else:
2128 assert a.dtype == DType.INT8
2129 table_dtype = DType.INT8
2130 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002131
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002132 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
2133 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002134 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002135
2136 return result_tens
2137
2138 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07002139 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002140 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002141 return result_tens
2142
2143 def build_comparison(self, op, a, b):
2144 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002145 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002146 return result_tens
2147
2148 def build_argmax(self, op, a, axis):
2149 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
2150
2151 attr = ts.TosaSerializerAttribute()
2152 attr.AxisAttribute(axis)
2153
Matthew Haddon848efb42021-09-09 12:30:53 +01002154 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002155 return result_tens
2156
Matthew Haddonb724efc2021-08-25 16:40:29 +01002157 def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07002158 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
2159
2160 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002161 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07002162
Matthew Haddon848efb42021-09-09 12:30:53 +01002163 self.ser.addOperator(op['op'], [input.name], [result_tens.name], attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002164 return result_tens
2165
2166 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002167 assert len(padding) == 4
2168 result_tens = OutputShaper.conv2dOp(
2169 self.ser, ifm, filter, strides, padding, dilations
2170 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002171
2172 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002173 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002174
Kevin Cheng550ccc52021-03-03 11:21:43 -08002175 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002176 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002177 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002178 return result_tens
2179
Kevin Cheng1533b852021-09-01 12:51:58 -07002180 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
2181 assert len(padding) == 6
2182 result_tens = OutputShaper.conv3dOp(
2183 self.ser, ifm, filter, strides, padding, dilations
2184 )
2185
2186 attr = ts.TosaSerializerAttribute()
2187 attr.ConvAttribute(padding, strides, dilations)
2188
2189 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002190 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07002191 )
2192 return result_tens
2193
Kevin Cheng550ccc52021-03-03 11:21:43 -08002194 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07002195 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002196 ):
2197 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07002198 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
2199
2200 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002201 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002202
Kevin Cheng550ccc52021-03-03 11:21:43 -08002203 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002204 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002205 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002206 return result_tens
2207
Kevin Cheng550ccc52021-03-03 11:21:43 -08002208 def build_depthwise_conv2d(
2209 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
2210 ):
2211 result_tens = OutputShaper.depthwiseConv2dOp(
2212 self.ser, ifm, filter, strides, padding, dilations
2213 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002214
2215 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07002216 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07002217
Kevin Cheng550ccc52021-03-03 11:21:43 -08002218 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002219 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002220 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002221 return result_tens
2222
2223 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
2224 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
2225
Kevin Cheng550ccc52021-03-03 11:21:43 -08002226 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002227 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002228 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002229 return result_tens
2230
2231 def build_matmul(self, op, a, b, qinfo):
2232 result_tens = OutputShaper.matmulOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002233 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002234 return result_tens
2235
2236 def build_reduce(self, op, a, axis):
2237 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
2238
2239 attr = ts.TosaSerializerAttribute()
2240 attr.AxisAttribute(axis)
2241
Matthew Haddon848efb42021-09-09 12:30:53 +01002242 self.ser.addOperator(op['op'], [a.name], result_tens.name, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002243 return result_tens
2244
2245 def build_clamp(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002246 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002247
2248 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01002249 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07002250
2251 if a.dtype == DType.FLOAT:
2252 attr.ClampAttribute(0, 0, min(v), max(v))
2253 else:
2254 attr.ClampAttribute(min(v), max(v), 0, 0)
2255
Matthew Haddon848efb42021-09-09 12:30:53 +01002256 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002257 return result_tens
2258
2259 def build_leaky_relu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002260 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002261 attr = ts.TosaSerializerAttribute()
2262
2263 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
2264
Matthew Haddon848efb42021-09-09 12:30:53 +01002265 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002266 return result_tens
2267
2268 # Needs an additional type/input
2269 def build_prelu(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002270 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002271
Matthew Haddon848efb42021-09-09 12:30:53 +01002272 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002273 return result_tens
2274
Eric Kunzee5e26762020-10-13 16:11:07 -07002275 def build_sigmoid(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002276 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01002277 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002278 return result_tens
2279
2280 def build_tanh(self, op, a):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002281 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Matthew Haddon848efb42021-09-09 12:30:53 +01002282 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002283 return result_tens
2284
Matthew Haddon818ab902021-07-27 09:12:49 +01002285 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07002286 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01002287
2288 # To store variable length list of input tensors we need to store axis along with it
2289 axis = a[-1]
2290 a = a[:-1]
2291
2292 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002293
2294 attr = ts.TosaSerializerAttribute()
2295 attr.AxisAttribute(axis)
2296
Matthew Haddon818ab902021-07-27 09:12:49 +01002297 input_tensor_names = []
2298 for tensor in a:
2299 input_tensor_names.append(tensor.name)
2300
Matthew Haddon848efb42021-09-09 12:30:53 +01002301 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
2302 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002303
2304 def build_pad(self, op, a, padding, qinfo):
2305 result_tens = OutputShaper.padOp(self.ser, a, padding)
2306
2307 # Need to turn the padding array into a TOSA tensor here.
2308 # This is one of the few tensor operands that does not get
2309 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08002310 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07002311
Kevin Cheng550ccc52021-03-03 11:21:43 -08002312 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002313 op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002314 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002315 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002316
2317 def build_reshape(self, op, a, newShape):
2318 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
2319
2320 attr = ts.TosaSerializerAttribute()
2321 attr.ReshapeAttribute(newShape)
2322
Matthew Haddon848efb42021-09-09 12:30:53 +01002323 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002324 return result_tens
2325
2326 def build_reverse(self, op, a, axis):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002327 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07002328
2329 attr = ts.TosaSerializerAttribute()
2330 attr.AxisAttribute(axis)
2331
Matthew Haddon848efb42021-09-09 12:30:53 +01002332 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002333 return result_tens
2334
2335 def build_transpose(self, op, a, perms):
2336 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
2337
Kevin Cheng550ccc52021-03-03 11:21:43 -08002338 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07002339
Matthew Haddon848efb42021-09-09 12:30:53 +01002340 self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002341 return result_tens
2342
2343 def build_slice(self, op, a, begin, size):
2344 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
2345
2346 attr = ts.TosaSerializerAttribute()
2347 attr.SliceAttribute(begin, size)
2348
Matthew Haddon848efb42021-09-09 12:30:53 +01002349 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002350 return result_tens
2351
2352 def build_tile(self, op, a, multiples):
2353 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
2354
2355 attr = ts.TosaSerializerAttribute()
2356 attr.TileAttribute(multiples)
2357
Matthew Haddon848efb42021-09-09 12:30:53 +01002358 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002359 return result_tens
2360
Kevin Cheng77d0f762020-11-24 10:26:32 -08002361 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07002362
2363 # Create a new indicies tensor
2364 # here with data that doesn't exceed the dimensions of the values tensor
2365
Kevin Cheng550ccc52021-03-03 11:21:43 -08002366 K = values.shape[1] # K
2367 W = self.randInt(
2368 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
2369 ) # W
2370 indicies_arr = np.int32(
2371 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
2372 ) # (N, W)
2373 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002374
Kevin Cheng77d0f762020-11-24 10:26:32 -08002375 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07002376
Matthew Haddon848efb42021-09-09 12:30:53 +01002377 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002378
2379 return result_tens
2380
Kevin Cheng77d0f762020-11-24 10:26:32 -08002381 def build_scatter(self, op, values_in, input):
2382
2383 # Create a new indicies tensor
2384 # here with data that doesn't exceed the dimensions of the values_in tensor
2385
Kevin Cheng550ccc52021-03-03 11:21:43 -08002386 K = values_in.shape[1] # K
2387 W = input.shape[1] # W
2388 indicies_arr = np.int32(
2389 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
2390 ) # (N, W)
2391 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002392
2393 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
2394
Kevin Cheng550ccc52021-03-03 11:21:43 -08002395 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002396 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002397 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002398
2399 return result_tens
2400
Matthew Haddon848efb42021-09-09 12:30:53 +01002401
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        validator_fcns,
        error_name = None,
    ):
        """Build a RESIZE operator, optionally as a negative (ERROR_IF) test.

        When `error_name` is set, the output shaper, the input/output list
        invalidation and the validator run may all deliberately produce an
        invalid graph so the reference model's ERROR_IF checks can be
        exercised.
        """
        result_tens = OutputShaper.resizeOp(
            self.ser,
            self.rng,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
            error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [input.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)

        # Run every requested ERROR_IF validator against the (possibly
        # invalidated) operator parameters.
        TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            mode=mode,
            shift=shift,
            input_dtype=input_dtype,
            output_dtype=output_dtype,
            input_shape=input,
            output_shape=output_dims,
            offset=offset,
            offset_fp=offset_fp,
            stride=stride,
            stride_fp=stride_fp,
            input_list=input_list,
            output_list=output_list,
            result_tensor=result_tens,
            num_operands=num_operands,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op['op'], input_list, output_list, attr)
        return result_tens
2470
2471 def build_identityn(self, op, val, val2):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002472 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
2473 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002474 self.ser.addOperator(
2475 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
2476 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002477 return result_tens
2478
    def build_const(self, op, val):
        """Build a CONST test: the pre-built const tensor is itself the output.

        No operator node is emitted; the tensor generator already created
        `val`, so it is simply registered as a network output.
        """
        self.ser.addOutputTensor(val)
        return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002482
2483 # Type Conversion
2484 def build_cast(self, op, val, out_dtype):
2485 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002486 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002487 return result_tens
2488
2489 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
2490 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2491
2492 if per_channel:
2493 nc = val.shape[-1]
2494 else:
2495 nc = 1
2496
2497 in_type_width = self.typeWidth(val.dtype)
2498 out_type_width = self.typeWidth(out_dtype)
2499
Kevin Cheng3a478572021-01-22 17:21:02 -08002500 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002501 input_zp = self.randInt(-128, 128)
2502 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002503 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002504 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002505 in_type_width = in_type_width + 1
2506 else:
2507 input_zp = 0
2508
Kevin Cheng3a478572021-01-22 17:21:02 -08002509 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002510 output_zp = self.randInt(-128, 128)
2511 out_type_width = out_type_width + 1
2512 elif out_dtype == DType.UINT8:
2513 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002514 out_type_width = out_type_width + 1
2515 else:
2516 output_zp = 0
2517
2518 # Calculate scale based on:
2519 # scale = a *(2^output_width)/(2^input_width))
2520
2521 a = np.float32(self.rng.random(size=[nc]))
2522 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2523
2524 if scale32:
2525 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002526 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002527 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2528 else:
2529 # Cap the scaling at 2^15 - 1 for scale16
2530 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2531
Kevin Cheng550ccc52021-03-03 11:21:43 -08002532 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002533
2534 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2535 shift_arr = np.int32(np.zeros(shape=[nc]))
2536
2537 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002538 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2539 scale_arr[i], scale32
2540 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002541
Kevin Cheng550ccc52021-03-03 11:21:43 -08002542 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002543
2544 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002545 attr.RescaleAttribute(
2546 input_zp,
2547 output_zp,
2548 multiplier_arr,
2549 shift_arr,
2550 scale32,
2551 double_round,
2552 per_channel,
2553 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002554
Matthew Haddon848efb42021-09-09 12:30:53 +01002555 self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002556 return result_tens
2557
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Build a COND_IF whose then/else blocks each hold a single CONST.

        The supplied then/else tensors are only used for their shape; fresh
        random INT32 const data is generated for each block body.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor (rank-0 BOOL).
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks.  Block order matters:
        # the operator is emitted in the current block before the serializer
        # is switched to the then/else basic blocks.
        self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
2593
    def build_cond_if_binary(self, op, a, b, cond):
        """Build a COND_IF that computes a+b in the then block, a-b in else.

        Exercises COND_IF with real data-flow through both branches (unlike
        build_cond_if_const, which uses constant-only bodies).
        """
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor (rank-0 BOOL).
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks.  The operator must be
        # emitted before the serializer switches to the branch blocks.
        self.ser.addOperator(
            op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
2627
2628 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002629 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002630
Kevin Cheng550ccc52021-03-03 11:21:43 -08002631 cond_block = "COND_BLOCK"
2632 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002633
2634 attr = ts.TosaSerializerAttribute()
2635 attr.WhileLoopAttribute(cond_block, body_block)
2636
2637 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002638 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002639 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002640 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002641
2642 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002643 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2644 a_out = self.ser.addIntermediate(a.shape, a.dtype)
2645 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002646
2647 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002648 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002649 op['op'],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002650 [iter.name, a.name, acc.name],
2651 [iter_out.name, a_out.name, acc_out.name],
2652 attr,
2653 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002654 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002655
2656 # COND block (input: iter, output: cond_tens )
2657 self.ser.startBasicBlock(cond_block)
2658 self.ser.addInputTensor(iter)
2659 self.ser.addInputTensor(a)
2660 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002661 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
2662 cond_tens = self.ser.addOutput([], DType.BOOL)
2663 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002664
2665 # BODY block (input: a, acc, iter, output: a, acc, iter)
2666 # Note that local intermediate tensors need to be declared here for the outputs
2667 self.ser.startBasicBlock(body_block)
2668 self.ser.addInputTensor(iter)
2669 self.ser.addInputTensor(a)
2670 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002671 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
2672 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2673 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002674 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2675 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2676 self.ser.addOutputTensor(iter_body_out)
2677 self.ser.addOutputTensor(a)
2678 self.ser.addOutputTensor(acc_body_out)
2679
2680 return acc_out
2681
Matthew Haddon1c00b712021-10-01 15:51:03 +01002682 def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
2683 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2684 default_test_rank_range = range(1, 5)
2685 if not shapeFilter:
2686 shapeFilter = [None]
2687
2688 # Calculate the filters based on what is requested and what the operator allows
2689 rmin, rmax = op["rank"]
2690 if rankFilter is not None:
2691 cleanRankFilter = []
2692 # Ensure rankFilter values are allowed by operator
2693 for rank in rankFilter:
2694 if rank >= rmin and rank <= rmax:
2695 cleanRankFilter.append(rank)
2696 elif rankFilter is None and shapeFilter[0] is None:
2697 cleanRankFilter = []
2698 # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
2699 rankRange = range(rmin, rmax + 1)
2700 for rank in rankRange:
2701 if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
2702 cleanRankFilter.append(rank)
2703 else:
2704 cleanRankFilter = range(rmin, rmax + 1)
2705
2706 dtypes = op["types"]
2707 if dtypeFilter is not None:
2708 cleanDtypeFilter = []
2709 # Ensure filtered dtypes are allowed by operator
2710 for dtype in dtypeFilter:
2711 if dtype in dtypes:
2712 cleanDtypeFilter.append(dtype)
2713 else:
2714 cleanDtypeFilter = dtypes
2715
2716 if testType == 'positive':
2717 filterDict = {
2718 'shapeFilter': shapeFilter,
2719 'rankFilter': cleanRankFilter,
2720 'dtypeFilter': cleanDtypeFilter
2721 }
2722 return filterDict
2723 elif testType == 'negative':
2724 validator_info = validator(check=False, op=op)
2725 error_arguments = validator_info['param_reqs']
2726
2727 #Set parameters as required
2728 if error_arguments['rank'] != None:
2729 rankFilter = error_arguments['rank']
2730 else:
2731 rankFilter = cleanRankFilter
2732
2733 if error_arguments['dtype'] != None:
2734 dtypeFilter = error_arguments['dtype']
2735 else:
2736 dtypeFilter = cleanDtypeFilter
2737
2738 if error_arguments['shape'] != None:
2739 shapeFilter = error_arguments['shape']
2740 else:
2741 shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
2742
2743 filterDict = {
2744 'shapeFilter': shapeFilter,
2745 'rankFilter': rankFilter,
2746 'dtypeFilter': dtypeFilter
2747 }
2748 return filterDict
2749
2750
Kevin Cheng550ccc52021-03-03 11:21:43 -08002751 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01002752 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08002753 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002754
2755 try:
2756 op = self.TOSA_OP_LIST[opName]
2757 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002758 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002759
2760 # Initialize a new random number generator
2761 self.rng = np.random.default_rng(self.random_seed)
2762
Kevin Cheng550ccc52021-03-03 11:21:43 -08002763 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002764
Eric Kunzee5e26762020-10-13 16:11:07 -07002765 # Test list consists of a tuple of:
2766 # (opName, testNameStr, dtype, shapeList, argumentsList)
2767 testList = []
Matthew Haddon1c00b712021-10-01 15:51:03 +01002768 if testType == 'negative' and "error_if_validators" in op:
2769 error_if_validators = op["error_if_validators"]
2770 else:
2771 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002772
Matthew Haddon1c00b712021-10-01 15:51:03 +01002773 for validator in error_if_validators:
2774 if validator is not None:
2775 error_name = validator(check=False, op=op)['error_name']
2776 #print("error_name: ", error_name)
2777 else:
2778 error_name = None
2779
2780 filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
2781 cleanRankFilter = filterDict['rankFilter']
2782 cleanDtypeFilter = filterDict['dtypeFilter']
2783 cleanShapeFilter = filterDict['shapeFilter']
2784 #print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
2785
2786 for r in cleanRankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07002787 if opName.startswith("conv3d"):
2788 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddon1c00b712021-10-01 15:51:03 +01002789 for t in cleanDtypeFilter:
2790 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002791 # Filter out by rank
2792 if shape is not None and len(shape) != r:
2793 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002794 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002795 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002796
Matthew Haddon74567092021-07-16 15:38:20 +01002797 shapeStr = self.shapeStr(shapeList[0])
2798 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002799
Matthew Haddon74567092021-07-16 15:38:20 +01002800 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2801 argList = []
2802 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002803 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002804 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002805 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002806
Matthew Haddon74567092021-07-16 15:38:20 +01002807 for argStr, args in argList:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002808 if testType == 'positive':
2809 if argStr:
2810 testStr = "{}_{}_{}_{}".format(
2811 opName, shapeStr, typeStr, argStr
2812 )
2813 else:
2814 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
2815 elif testType == 'negative':
Matthew Haddone86fd342021-09-07 16:12:21 +01002816 if argStr:
2817 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2818 opName, error_name, shapeStr, typeStr, argStr
2819 )
2820 else:
2821 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002822
2823 testList.append((opName, testStr, t, error_name, shapeList, args))
2824
2825 if testType == 'positive':
2826 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2827 if "invalid_test_validators" in op:
2828 invalid_test_validators = op["invalid_test_validators"]
2829 clean_testList = []
2830 for test in testList:
2831 for validator_fcn in invalid_test_validators:
2832 remove_test = False
2833 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
2834 remove_test = True
2835 if not remove_test:
2836 clean_testList.append(test)
2837 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002838
2839 return testList
2840
Matthew Haddone86fd342021-09-07 16:12:21 +01002841
    def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
        """Build one test (from a genOpTestList tuple) and serialize it to disk.

        Generates the operand tensors, optional quantization info, invokes the
        operator's build function, then writes the serialized test named
        `testStr`.
        """
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        if "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
        else:
            error_if_validators = None

        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        # CONCAT takes a variable number of inputs, so its dtype list is
        # sized by shapeList rather than by the declared operand count.
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        elif op["op"] == Op.CONCAT:
            dtypeList = [dtype_or_dtypeList] * len(shapeList)
        else:
            dtypeList = [dtype_or_dtypeList] * (num_operands)

        if op["op"] != Op.CONCAT:
            assert (
                len(shapeList) == num_operands
            ), "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
            assert (
                len(dtypeList) == num_operands
            ), "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )

        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = []

        tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)

        if qgen is not None:
            qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
        else:
            qinfo = None

        # Dispatch to the build function; positional-argument order must
        # match each build_* signature (qinfo last, validators/error_name
        # before it when the op declares ERROR_IF validators).
        try:
            if error_if_validators is None:
                if qinfo is not None:
                    resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
                else:
                    resultName = build_fcn(self, op, *tens, *testArgs)
            else:
                if qinfo is not None:
                    resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
                else:
                    resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
        except TypeError as e:
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        if resultName is None:
            print("Invalid ERROR_IF tests created")

        # Save the serialized test
        self.serialize("test")
2918
2919
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002920 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002921 pCount, cCount = op["operands"]
2922
2923 tens = []
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002924 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
Jeremy Johnsonef509a42021-09-07 13:59:47 +01002925 # Make sure the operation does not cause value saturation - where
2926 # the number wraps due to limited number of bits to store the answer
2927 assert (
2928 pCount == 2 and cCount == 0
2929 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01002930 placeholders = []
2931 add = (op["op"] == Op.ADD)
2932 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2933 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2934 if add:
2935 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
2936 else:
2937 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
2938
2939 # Work out the saturation limits
2940 max_i32 = (1 << 31)-1
2941 min_i32 = -(1 << 31)
2942 max_arr = np.full(shapeList[1], max_i32)
2943 min_arr = np.full(shapeList[1], min_i32)
2944
2945 # Find how much values exceed the maximum/minimums
2946 sat_max_arr = np.maximum(res_arr - max_arr, 0)
2947 sat_min_arr = np.minimum(res_arr - min_arr, 0)
2948
2949 if not add:
2950 # Swap saturation values and negate values as we need to perform opposite operations
2951 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
2952
2953 # Create new array of unsaturated values by clipping values as needed
2954 b_unsat_arr = b_arr
2955 if (sat_max_arr != 0).any():
2956 # Clip values that cause saturation
2957 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
2958 # Reduce axes in unsaturated tensor to match original tensor
2959 for axis, dim in enumerate(b_arr.shape):
2960 if dim != b_unsat_arr.shape[axis]:
2961 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2962 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
2963
2964 if (sat_min_arr != 0).any():
2965 # Clip values that cause saturation
2966 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
2967 # Reduce axes in unsaturated tensor to match original tensor
2968 for axis, dim in enumerate(b_arr.shape):
2969 if dim != b_unsat_arr.shape[axis]:
2970 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2971 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
2972
2973 placeholders.append(
2974 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2975 )
2976 placeholders.append(
2977 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
2978 )
2979
2980 tens.extend(placeholders)
2981 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
2982 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002983 assert (
2984 pCount == 2 and cCount == 0
2985 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08002986
2987 placeholders = []
2988 for idx, shape in enumerate(shapeList[:]):
2989 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07002990 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002991 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002992 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002993 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002994 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002995 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
2996 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002997 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002998 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002999 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07003000 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08003001
3002 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01003003 elif op["op"] == Op.SELECT:
3004 # Set datatype of condition tensor to boolean
3005 dtypeList[0] = DType.BOOL
3006 tens.extend(
3007 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
3008 )
3009 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003010 elif op["op"] == Op.INTDIV and error_name == None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003011 assert (
3012 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01003013 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003014
3015 placeholders = []
3016
Matthew Haddon459443c2021-08-23 16:43:13 +01003017 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003018 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07003019 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003020 while True:
3021 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
3022 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
3023
3024 if (divisor_arr == 0).any():
3025 continue
3026
Kevin Cheng47315e12021-05-13 17:41:28 -07003027 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003028 continue
3029
3030 break
3031
3032 placeholders.append(
3033 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
3034 )
3035 placeholders.append(
3036 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
3037 )
3038
3039 tens.extend(placeholders)
3040 elif op["op"] == Op.MUL:
3041 assert (
3042 pCount == 2 and cCount == 0
3043 ), "Op.MUL must have 2 placeholders, 0 consts"
3044
3045 if dtypeList[0] == DType.FLOAT:
3046 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
3047 else:
3048 placeholders = []
3049
3050 # Make sure multiply result in int32 range
3051 shift = testArgs[0]
3052 if dtypeList[0] == DType.INT8:
3053 num_bits = 8
3054 elif dtypeList[0] == DType.INT16:
3055 num_bits = 16
3056 elif dtypeList[0] == DType.INT32:
3057 num_bits = 32
3058 else:
3059 raise Exception("OpMul: invalid input dtype")
3060
3061 for idx, shape in enumerate(shapeList[:]):
3062 low = -(2 ** (num_bits - 1))
3063 high = (2 ** (num_bits - 1)) - 1
3064
3065 a_arr = np.int32(
3066 self.rng.integers(low=low, high=high, size=shapeList[0])
3067 )
3068 b_arr = np.int32(
3069 self.rng.integers(low=low, high=high, size=shapeList[1])
3070 )
3071
3072 i = 0
3073 while True:
3074
3075 a_arr_64 = a_arr.astype(np.int64)
3076 b_arr_64 = b_arr.astype(np.int64)
3077
3078 if shift > 0:
3079 rounding = 1 << (shift - 1)
3080 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
3081 else:
3082 result_arr = a_arr_64 * b_arr_64
3083
3084 if (result_arr > -(2 ** 31)).all() and (
3085 result_arr <= ((2 ** 31) - 1)
3086 ).all():
3087 break
3088
3089 i = i + 1
3090 a_arr = a_arr // 2
3091 b_arr = b_arr // 2
3092
3093 placeholders.append(
3094 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
3095 )
3096 placeholders.append(
3097 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
3098 )
3099
3100 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01003101 elif op["op"] == Op.CONCAT:
3102 count = len(shapeList) - self.args.num_const_inputs_concat
3103 if count < 1:
3104 count = 1
3105 if self.args.num_const_inputs_concat == 0:
3106 count = len(shapeList)
3107
3108 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
3109 tens.extend(
3110 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
3111 )
3112 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08003113 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07003114 tens.extend(
3115 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
3116 )
3117 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07003118
Matthew Haddon1c00b712021-10-01 15:51:03 +01003119 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07003120
3121 def createDynamicOpLists(self):
3122
3123 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07003124 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003125
Kevin Cheng1533b852021-09-01 12:51:58 -07003126 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003127 testName = "conv2d_{}x{}".format(k[0], k[1])
3128 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
3129 self.TOSA_OP_LIST[testName]["filter"] = k
3130 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003131
Kevin Cheng550ccc52021-03-03 11:21:43 -08003132 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
3133 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3134 "depthwise_conv2d_TEMPLATE"
3135 ].copy()
3136 self.TOSA_OP_LIST[testName]["filter"] = k
3137 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003138
Kevin Cheng550ccc52021-03-03 11:21:43 -08003139 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
3140 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3141 "transpose_conv2d_TEMPLATE"
3142 ].copy()
3143 self.TOSA_OP_LIST[testName]["filter"] = k
3144 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003145
Kevin Cheng1533b852021-09-01 12:51:58 -07003146 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
3147 for k in KERNELS_3D:
3148 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
3149 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
3150 self.TOSA_OP_LIST[testName]["filter"] = k
3151 self.TOSA_OP_LIST[testName]["template"] = False
3152
Eric Kunzee5e26762020-10-13 16:11:07 -07003153 # Delete any templates after having created any dynamic ops
3154 # This is a two-pass operation because it's bad practice to delete
3155 # keys from dictionaries while iterating
3156 keyList = []
3157 for k in self.TOSA_OP_LIST:
3158 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003159 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07003160 keyList.append(k)
3161 continue
3162 except KeyError:
3163 pass
3164
3165 for k in keyList:
3166 del self.TOSA_OP_LIST[k]
3167
3168 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003169 """Fill in default fields for ops if they aren't already specified.
3170 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07003171 for op in self.TOSA_OP_LIST:
3172
3173 # Required fields
3174 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003175 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003176 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003177 raise Exception(
3178 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
3179 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003180
3181 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003182 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003183 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003184 raise Exception(
3185 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
3186 op
3187 )
3188 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003189
3190 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003191 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003192 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003193 raise Exception(
3194 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
3195 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003196
3197 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003198 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003199 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003200 raise Exception(
3201 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
3202 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003203
3204 # Put in default rank range, if missing
3205 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003206 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003207 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003208 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003209
3210 # Tensor operator list
3211 # 'op': op name
3212 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08003213 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
3214 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07003215 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
3216 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08003217 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07003218
Kevin Cheng550ccc52021-03-03 11:21:43 -08003219 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
3220 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07003221
Kevin Cheng550ccc52021-03-03 11:21:43 -08003222 TYPE_BOOL = [DType.BOOL]
3223 TYPE_FI32 = [DType.FLOAT, DType.INT32]
3224 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
3225 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003226
Kevin Cheng550ccc52021-03-03 11:21:43 -08003227 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07003228
Kevin Cheng1533b852021-09-01 12:51:58 -07003229 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003230 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003231 [DType.INT8, DType.INT8, DType.INT32],
3232 [DType.INT16, DType.INT8, DType.INT48],
3233 DType.FLOAT,
3234 ]
3235
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003236 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003237
3238 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003239 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003240 "argmax": {
3241 "op": Op.ARGMAX,
3242 "operands": (1, 0),
3243 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3244 "types": TYPE_NARROW_INT_FP,
3245 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 "avg_pool2d": {
3247 "op": Op.AVG_POOL2D,
3248 "operands": (1, 0),
3249 "rank": (4, 4),
3250 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3251 "qgen": TosaQuantGen.qgUnary,
3252 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003253 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003254 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003255 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003256 "conv2d_TEMPLATE": {
3257 "op": Op.CONV2D,
3258 "operands": (1, 2),
3259 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01003260 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003261 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003262 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003263 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003264 "template": True,
3265 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003266 # Templated operator. Filled in by createDynamicOpLists
3267 "conv3d_TEMPLATE": {
3268 "op": Op.CONV3D,
3269 "operands": (1, 2),
3270 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01003271 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07003272 "qgen": TosaQuantGen.qgConv,
3273 "types": TYPE_CONV,
3274 "template": True,
3275 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003276 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003277 "depthwise_conv2d_TEMPLATE": {
3278 "op": Op.DEPTHWISE_CONV2D,
3279 "operands": (1, 2),
3280 "filter": [1, 1],
3281 "rank": (4, 4),
3282 "build_fcn": (
3283 build_depthwise_conv2d,
3284 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01003285 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003286 ),
3287 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003288 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003289 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003290 "template": True,
3291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 "fully_connected": {
3293 "op": Op.FULLY_CONNECTED,
3294 "operands": (1, 2),
3295 "rank": (2, 2),
3296 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
3297 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003298 "types": TYPE_CONV,
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 "matmul": {
3301 "op": Op.MATMUL,
3302 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003303 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
3305 "qgen": TosaQuantGen.qgMatmul,
3306 "types": TYPE_NARROW_INT_FP,
3307 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 "max_pool2d": {
3309 "op": Op.MAX_POOL2D,
3310 "operands": (1, 0),
3311 "rank": (4, 4),
3312 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3313 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003314 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003315 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003316 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003317 "transpose_conv2d_TEMPLATE": {
3318 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003319 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003320 "rank": (4, 4),
3321 "build_fcn": (
3322 build_transpose_conv2d,
3323 TosaTensorGen.tgTransposeConv2D,
3324 TosaArgGen.agTransposeConv2D,
3325 ),
3326 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003327 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003328 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003329 "template": True,
3330 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003331 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003332 "clamp": {
3333 "op": Op.CLAMP,
3334 "operands": (1, 0),
3335 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
3336 "types": TYPE_NARROW_INT_FP,
3337 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003338 "sigmoid": {
3339 "op": Op.SIGMOID,
3340 "operands": (1, 0),
3341 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
3342 "types": TYPE_FP,
3343 },
3344 "tanh": {
3345 "op": Op.TANH,
3346 "operands": (1, 0),
3347 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
3348 "types": TYPE_FP,
3349 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 # Elementwise Binary Operators
3351 "add": {
3352 "op": Op.ADD,
3353 "operands": (2, 0),
3354 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3355 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003356 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3357 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003358 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 "arithmetic_right_shift": {
3360 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3361 "operands": (2, 0),
3362 "build_fcn": (
3363 build_arithmetic_right_shift,
3364 TosaTensorGen.tgBroadcastFuzz,
3365 TosaArgGen.agArithmeticRightShift,
3366 ),
3367 "types": TYPE_INT,
3368 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003369 "bitwise_and": {
3370 "op": Op.BITWISE_AND,
3371 "operands": (2, 0),
3372 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3373 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003374 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3375 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "bitwise_or": {
3378 "op": Op.BITWISE_OR,
3379 "operands": (2, 0),
3380 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3381 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003382 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3383 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 "bitwise_xor": {
3386 "op": Op.BITWISE_XOR,
3387 "operands": (2, 0),
3388 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3389 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003390 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3391 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003393 "intdiv": {
3394 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003395 "operands": (2, 0),
3396 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3397 "types": [DType.INT32],
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003398 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3399 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003400 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003401 "logical_and": {
3402 "op": Op.LOGICAL_AND,
3403 "operands": (2, 0),
3404 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3405 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003406 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3407 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "logical_left_shift": {
3410 "op": Op.LOGICAL_LEFT_SHIFT,
3411 "operands": (2, 0),
3412 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3413 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003414 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003416 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003417 "logical_right_shift": {
3418 "op": Op.LOGICAL_RIGHT_SHIFT,
3419 "operands": (2, 0),
3420 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3421 "types": TYPE_INT,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003422 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3423 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 "logical_or": {
3426 "op": Op.LOGICAL_OR,
3427 "operands": (2, 0),
3428 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3429 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003430 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3431 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "logical_xor": {
3434 "op": Op.LOGICAL_XOR,
3435 "operands": (2, 0),
3436 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3437 "types": TYPE_BOOL,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003438 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3439 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003441 "maximum": {
3442 "op": Op.MAXIMUM,
3443 "operands": (2, 0),
3444 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3445 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003446 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3447 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003448 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003449 "minimum": {
3450 "op": Op.MINIMUM,
3451 "operands": (2, 0),
3452 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3453 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003454 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3455 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003457 "mul": {
3458 "op": Op.MUL,
3459 "operands": (2, 0),
3460 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
3461 "types": TYPE_INT_FP,
3462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 "pow": {
3464 "op": Op.POW,
3465 "operands": (2, 0),
3466 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
3467 "types": TYPE_FP,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003468 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3469 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003470 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 "sub": {
3472 "op": Op.SUB,
3473 "operands": (2, 0),
3474 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3475 "types": TYPE_FI32,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003476 "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 "table": {
3480 "op": Op.TABLE,
3481 # Use the automatic generation functions to create the input array
3482 # but create the table tensor in the build function, as it may be
3483 # a different type from the input
3484 "operands": (1, 0),
3485 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003486 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08003487 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 # Elementwise Unary operators
3489 "abs": {
3490 "op": Op.ABS,
3491 "operands": (1, 0),
3492 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3493 "types": TYPE_FI32,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003494 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3495 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003497 "bitwise_not": {
3498 "op": Op.BITWISE_NOT,
3499 "operands": (1, 0),
3500 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3501 "types": TYPE_INT,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003502 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 "ceil": {
3506 "op": Op.CEIL,
3507 "operands": (1, 0),
3508 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3509 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003510 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3511 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "clz": {
3514 "op": Op.CLZ,
3515 "operands": (1, 0),
3516 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3517 "types": [DType.INT32],
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003518 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3519 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003520 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "exp": {
3522 "op": Op.EXP,
3523 "operands": (1, 0),
3524 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3525 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003526 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3527 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 "floor": {
3530 "op": Op.FLOOR,
3531 "operands": (1, 0),
3532 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3533 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003534 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3535 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 "log": {
3538 "op": Op.LOG,
3539 "operands": (1, 0),
3540 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3541 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003542 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3543 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "logical_not": {
3546 "op": Op.LOGICAL_NOT,
3547 "operands": (1, 0),
3548 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3549 "types": TYPE_BOOL,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003550 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3551 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003552 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003553 "negate": {
3554 "op": Op.NEGATE,
3555 "operands": (1, 0),
3556 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3557 "qgen": TosaQuantGen.qgUnary,
3558 "types": TYPE_INT_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003559 "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
3560 TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
3561 TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003562 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003563 "reciprocal": {
3564 "op": Op.RECIPROCAL,
3565 "operands": (1, 0),
3566 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3567 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003568 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3569 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 "rsqrt": {
3572 "op": Op.RSQRT,
3573 "operands": (1, 0),
3574 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3575 "types": TYPE_FP,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003576 "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
3577 TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Jared Smolens573ecd42021-03-04 15:24:10 -08003578 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003579 # Elementwise Ternary operators
3580 "select": {
3581 "op": Op.SELECT,
3582 "operands": (3, 0),
3583 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
3584 "types": TYPE_FIB,
3585 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003586 # Comparison operators
3587 "equal": {
3588 "op": Op.EQUAL,
3589 "operands": (2, 0),
3590 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3591 "types": TYPE_FI32,
3592 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 "greater_equal": {
3594 "op": Op.GREATER_EQUAL,
3595 "operands": (2, 0),
3596 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3597 "types": TYPE_FI32,
3598 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003599 "greater": {
3600 "op": Op.GREATER,
3601 "operands": (2, 0),
3602 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3603 "types": TYPE_FI32,
3604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 # Reduction operators
3606 "reduce_all": {
3607 "op": Op.REDUCE_ALL,
3608 "operands": (1, 0),
3609 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3610 "types": TYPE_BOOL,
3611 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003612 "reduce_any": {
3613 "op": Op.REDUCE_ANY,
3614 "operands": (1, 0),
3615 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3616 "types": TYPE_BOOL,
3617 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 "reduce_max": {
3619 "op": Op.REDUCE_MAX,
3620 "operands": (1, 0),
3621 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3622 "types": TYPE_INT_FP,
3623 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003624 "reduce_min": {
3625 "op": Op.REDUCE_MAX,
3626 "operands": (1, 0),
3627 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3628 "types": TYPE_INT_FP,
3629 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003630 "reduce_product": {
3631 "op": Op.REDUCE_PRODUCT,
3632 "operands": (1, 0),
3633 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3634 "types": TYPE_FP,
3635 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "reduce_sum": {
3637 "op": Op.REDUCE_SUM,
3638 "operands": (1, 0),
3639 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3640 "types": TYPE_FI32,
3641 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003642 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003643 "concat": {
3644 "op": Op.CONCAT,
3645 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01003646 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 "types": TYPE_FIB,
3648 },
3649 "pad": {
3650 "op": Op.PAD,
3651 "operands": (1, 0),
3652 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
3653 "qgen": TosaQuantGen.qgPad,
3654 "types": TYPE_FIB,
3655 },
3656 "reshape": {
3657 "op": Op.RESHAPE,
3658 "operands": (1, 0),
3659 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
3660 "types": TYPE_FIB,
3661 },
3662 "reverse": {
3663 "op": Op.REVERSE,
3664 "operands": (1, 0),
3665 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3666 "types": TYPE_FIB,
3667 },
3668 "slice": {
3669 "op": Op.SLICE,
3670 "operands": (1, 0),
3671 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
3672 "types": TYPE_FIB,
3673 },
3674 "tile": {
3675 "op": Op.TILE,
3676 "operands": (1, 0),
3677 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
3678 "types": TYPE_FIB,
3679 },
3680 "transpose": {
3681 "op": Op.TRANSPOSE,
3682 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003683 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003684 "build_fcn": (
3685 build_transpose,
3686 TosaTensorGen.tgBasic,
3687 TosaArgGen.agTranspose,
3688 ),
3689 "types": TYPE_FIB,
3690 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003691 # Data nodes
3692 "const": {
3693 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003694 "operands": (0, 1),
3695 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "types": TYPE_FIB,
3697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 "identity": {
3699 "op": Op.IDENTITY,
3700 "operands": (1, 0),
3701 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3702 "types": TYPE_FIB,
3703 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003704 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003705 "gather": {
3706 "op": Op.GATHER,
3707 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3708 "operands": (1, 0),
3709 "rank": (3, 3),
3710 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
3711 "types": TYPE_INT_FP,
3712 },
3713 "scatter": {
3714 "op": Op.SCATTER,
3715 # Only specify 'values_in' tensor here.
3716 #'indices' and 'input' are generated in op building stage
3717 "operands": (2, 0),
3718 "rank": (3, 3),
3719 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
3720 "types": TYPE_INT_FP,
3721 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003722 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003723 "resize": {
3724 "op": Op.RESIZE,
3725 "operands": (1, 0),
3726 "rank": (4, 4),
3727 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
3728 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01003729 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
3730 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
3731 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01003732 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01003733 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
3734 TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003735 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003736 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003737 "cast": {
3738 "op": Op.CAST,
3739 "operands": (1, 0),
3740 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
3741 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
3742 },
3743 "rescale": {
3744 "op": Op.RESCALE,
3745 "operands": (1, 0),
3746 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003747 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08003748 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003749 # Custom
3750 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003751 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003752 # Two varients of cond_if, one that generates one of two constant tensors (no
3753 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3754 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003755 "cond_if_const": {
3756 "op": Op.COND_IF,
3757 "operands": (0, 2),
3758 "build_fcn": (
3759 build_cond_if_const,
3760 TosaTensorGen.tgBasic,
3761 TosaArgGen.agCondIf,
3762 ),
3763 "types": [DType.BOOL],
3764 },
3765 "cond_if_binary": {
3766 "op": Op.COND_IF,
3767 "operands": (2, 0),
3768 "build_fcn": (
3769 build_cond_if_binary,
3770 TosaTensorGen.tgBasic,
3771 TosaArgGen.agCondIf,
3772 ),
3773 "types": TYPE_FI32,
3774 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003775 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003776 "while_loop": {
3777 "op": Op.WHILE_LOOP,
3778 "operands": (0, 1),
3779 "build_fcn": (
3780 build_while_loop,
3781 TosaTensorGen.tgBasic,
3782 TosaArgGen.agWhileLoop,
3783 ),
3784 "types": [DType.INT32],
3785 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003786 }
3787
Kevin Cheng550ccc52021-03-03 11:21:43 -08003788
Eric Kunzee5e26762020-10-13 16:11:07 -07003789class OutputShaper:
3790 # Methods in this class compute the expected output shape and datatype
3791 # for common classes of operations
3792 def __init__(self):
3793 pass
3794
3795 # These methods return arguments that can be used for
3796 # creating a new output tensor
3797 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003798 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
3799 if error_name != ErrorIf.RankMismatch:
3800 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003801 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003802
3803 shape = []
3804 for i in range(len(a.shape)):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003805 if a.shape[i] == 1 and error_name == None:
Eric Kunzee5e26762020-10-13 16:11:07 -07003806 shape.append(b.shape[i])
3807 else:
3808 shape.append(a.shape[i])
3809
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003810 if error_name == ErrorIf.WrongOutputType:
3811 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
3812 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3813 outputDType = rng.choice(wrong_dtypes)
3814 else:
3815 outputDType = a.dtype
3816
3817 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003818
3819 @staticmethod
3820 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003821 assert len(a.shape) == len(b.shape)
3822 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003823
3824 shape = []
3825 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003826 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003827 shape.append(a.shape[i])
3828
Kevin Cheng550ccc52021-03-03 11:21:43 -08003829 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003830
3831 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003832 def unaryOp(ser, rng, a, error_name=None):
3833 if error_name == ErrorIf.WrongOutputType:
3834 all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
3835 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3836 outputDType = rng.choice(wrong_dtypes)
3837 else:
3838 outputDType = a.dtype
3839
3840 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003841
3842 @staticmethod
3843 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003844 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
3845 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003846
3847 shape = []
3848 for i in range(len(a.shape)):
3849 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3850
Kevin Cheng550ccc52021-03-03 11:21:43 -08003851 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003852
3853 @staticmethod
3854 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003855 assert len(a.shape) == len(b.shape)
3856 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003857
3858 # Do broadcast
3859 shape = []
3860 for i in range(len(a.shape)):
3861 if a.shape[i] == 1:
3862 shape.append(b.shape[i])
3863 else:
3864 shape.append(a.shape[i])
3865
3866 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08003867 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07003868
3869 @staticmethod
3870 def reduceOp(ser, a, axis):
3871
3872 shape = a.shape.copy()
3873
3874 shape[axis] = 1
3875
Kevin Cheng550ccc52021-03-03 11:21:43 -08003876 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003877
3878 @staticmethod
3879 def argmaxOp(ser, a, axis):
3880 shape = a.shape.copy()
3881 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003882 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07003883
3884 @staticmethod
3885 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
3886
3887 # IFM: NHWC
3888 # Filter: OHWI
3889 # OFM: NHWC
3890
3891 if len(padding) == 2:
3892 # Expand padding to 4 parameters in the case of transpose_conv2d
3893 # From H,W to T,B,L,R
3894 padding = [padding[0], padding[0], padding[1], padding[1]]
3895
Kevin Cheng550ccc52021-03-03 11:21:43 -08003896 h = (
3897 ifm.shape[1]
3898 - filter.shape[1]
3899 - (filter.shape[1] - 1) * (dilations[0] - 1)
3900 + padding[0]
3901 + padding[1]
3902 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003903
Kevin Cheng550ccc52021-03-03 11:21:43 -08003904 w = (
3905 ifm.shape[2]
3906 - filter.shape[2]
3907 - (filter.shape[2] - 1) * (dilations[1] - 1)
3908 + padding[2]
3909 + padding[3]
3910 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003911
Eric Kunzee5e26762020-10-13 16:11:07 -07003912 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
3913
Kevin Cheng3a478572021-01-22 17:21:02 -08003914 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003915 out_dtype = DType.INT32
3916 elif ifm.dtype == DType.INT16:
3917 out_dtype = DType.INT48
3918 elif ifm.dtype == DType.FLOAT:
3919 out_dtype = DType.FLOAT
3920 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003921 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003922
Kevin Cheng550ccc52021-03-03 11:21:43 -08003923 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003924
3925 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07003926 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
3927
3928 # IFM: NDHWC
3929 # Filter: ODHWI
3930 # OFM: NDHWC
3931
3932 d = (
3933 ifm.shape[1]
3934 - filter.shape[1]
3935 - (filter.shape[1] - 1) * (dilations[0] - 1)
3936 + padding[0]
3937 + padding[1]
3938 ) // strides[0] + 1
3939
3940 h = (
3941 ifm.shape[2]
3942 - filter.shape[2]
3943 - (filter.shape[2] - 1) * (dilations[1] - 1)
3944 + padding[2]
3945 + padding[3]
3946 ) // strides[1] + 1
3947
3948 w = (
3949 ifm.shape[3]
3950 - filter.shape[3]
3951 - (filter.shape[3] - 1) * (dilations[2] - 1)
3952 + padding[4]
3953 + padding[5]
3954 ) // strides[2] + 1
3955
3956 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
3957
3958 if ifm.dtype == DType.INT8:
3959 out_dtype = DType.INT32
3960 elif ifm.dtype == DType.INT16:
3961 out_dtype = DType.INT48
3962 elif ifm.dtype == DType.FLOAT:
3963 out_dtype = DType.FLOAT
3964 else:
3965 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
3966
3967 return ser.addOutput(ofm_shape, out_dtype)
3968
3969 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07003970 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
3971 # IFM: NHWC
3972 # Filter: HWCM
3973 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08003974 h = (
3975 ifm.shape[1]
3976 - filter.shape[0]
3977 - (filter.shape[0] - 1) * (dilations[0] - 1)
3978 + padding[0]
3979 + padding[1]
3980 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003981
Kevin Cheng550ccc52021-03-03 11:21:43 -08003982 w = (
3983 ifm.shape[2]
3984 - filter.shape[1]
3985 - (filter.shape[1] - 1) * (dilations[1] - 1)
3986 + padding[2]
3987 + padding[3]
3988 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003989
Eric Kunzee5e26762020-10-13 16:11:07 -07003990 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
3991
Kevin Cheng3a478572021-01-22 17:21:02 -08003992 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003993 out_dtype = DType.INT32
3994 elif ifm.dtype == DType.INT16:
3995 out_dtype = DType.INT48
3996 elif ifm.dtype == DType.FLOAT:
3997 out_dtype = DType.FLOAT
3998 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003999 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004000
Kevin Cheng550ccc52021-03-03 11:21:43 -08004001 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004002
4003 @staticmethod
4004 def pool2dOp(ser, ifm, kernel, stride, pad):
4005 # input: NHWC
4006 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
4007 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
4008
Eric Kunzee5e26762020-10-13 16:11:07 -07004009 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004010 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004011
4012 @staticmethod
4013 def fullyConnectedOp(ser, input, filter):
4014 # input: N, IC
4015 # filter: OC, IC
4016 # output: N, OC
4017
4018 output_shape = [input.shape[0], filter.shape[0]]
4019
Kevin Cheng3a478572021-01-22 17:21:02 -08004020 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004021 out_dtype = DType.INT32
4022 elif input.dtype == DType.INT16:
4023 out_dtype = DType.INT48
4024 elif input.dtype == DType.FLOAT:
4025 out_dtype = DType.FLOAT
4026 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004027 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004028
Kevin Cheng550ccc52021-03-03 11:21:43 -08004029 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004030
4031 @staticmethod
4032 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004033 # a: N, H, C
4034 # b: N, C, W
4035 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004036
Kevin Cheng2d60f002021-06-09 14:18:32 -07004037 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004038
Kevin Cheng3a478572021-01-22 17:21:02 -08004039 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004040 out_dtype = DType.INT32
4041 elif a.dtype == DType.INT16:
4042 out_dtype = DType.INT48
4043 elif a.dtype == DType.FLOAT:
4044 out_dtype = DType.FLOAT
4045 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004046 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004047
Kevin Cheng550ccc52021-03-03 11:21:43 -08004048 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004049
4050 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01004051 def concatOp(ser, axis, *a):
4052 input1 = a[0]
4053 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004054
Matthew Haddon818ab902021-07-27 09:12:49 +01004055 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07004056
Matthew Haddon818ab902021-07-27 09:12:49 +01004057 output_shape[axis] = input1.shape[axis]
4058
4059 for tensor in remaining_inputs:
4060 output_shape[axis] += tensor.shape[axis]
4061
4062 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004063
4064 @staticmethod
4065 def padOp(ser, a, padding):
4066
4067 output_shape = a.shape.copy()
4068
4069 for i in range(len(output_shape)):
4070 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4071
Kevin Cheng550ccc52021-03-03 11:21:43 -08004072 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004073
4074 @staticmethod
4075 def reshapeOp(ser, a, shape):
4076 output_shape = shape.copy()
4077
4078 totalElements = 1
4079 for i in a.shape:
4080 totalElements *= i
4081
4082 # If there are any -1 elements, figure out what that dimension must be
4083 totalOutputElements = 1
4084 for i in output_shape:
4085 if i != -1:
4086 totalOutputElements *= i
4087
4088 # And fill it in
4089 for i in range(len(output_shape)):
4090 if output_shape[i] == -1:
4091 output_shape[i] = totalElements // totalOutputElements
4092
Kevin Cheng550ccc52021-03-03 11:21:43 -08004093 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004094
4095 @staticmethod
4096 def sliceOp(ser, a, begin, size):
4097
4098 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004099 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004100
4101 @staticmethod
4102 def tileOp(ser, a, multiples):
4103
4104 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004105 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004106
4107 for i in range(len(output_shape)):
4108 output_shape[i] = a.shape[i] * multiples[i]
4109
Kevin Cheng550ccc52021-03-03 11:21:43 -08004110 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004111
4112 @staticmethod
4113 def transposeOp(ser, a, perms):
4114 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004115 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004116
4117 for i in range(len(output_shape)):
4118 output_shape[i] = a.shape[perms[i]]
4119
Kevin Cheng550ccc52021-03-03 11:21:43 -08004120 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004121
4122 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08004123 def gatherOp(ser, values, indices):
4124 assert len(values.shape) == 3
4125 assert len(indices.shape) == 2
4126 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004127
Kevin Cheng77d0f762020-11-24 10:26:32 -08004128 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4129
Kevin Cheng550ccc52021-03-03 11:21:43 -08004130 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004131
4132 @staticmethod
4133 def scatterOp(ser, values_in, indices, input):
4134 assert len(values_in.shape) == 3
4135 assert len(indices.shape) == 2
4136 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004137 assert values_in.shape[0] == indices.shape[0] # N
4138 assert input.shape[1] == indices.shape[1] # W
4139 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004140
4141 output_shape = values_in.shape
4142
Kevin Cheng550ccc52021-03-03 11:21:43 -08004143 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004144
4145 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004146 def tableOp(ser, input, table_dtype):
4147 # Same shape as the input, but dtype dependent on table dtype
4148 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
4149 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
4150 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004151
4152 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004153 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004154 serializer,
4155 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004156 input,
4157 mode,
4158 stride,
4159 offset,
4160 shift,
4161 stride_fp,
4162 offset_fp,
4163 output_dims,
4164 input_dtype,
4165 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004166 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08004167 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01004168 if error_name == ErrorIf.WrongRank:
4169 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
4170 else:
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004171 if error_name == ErrorIf.BatchMismatch:
4172 output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
4173 elif error_name == ErrorIf.ChannelMismatch:
4174 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
4175 else:
4176 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004177
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004178 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004179
    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
        # CAST/RESCALE output: shape preserved, dtype replaced by out_dtype.
        return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004183
4184 @staticmethod
4185 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08004186 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07004187 out_dtype = DType.INT32
4188 elif ifm.dtype == DType.INT16:
4189 out_dtype = DType.INT48
4190 elif ifm.dtype == DType.FLOAT:
4191 out_dtype = DType.FLOAT
4192 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08004193 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07004194
Kevin Cheng550ccc52021-03-03 11:21:43 -08004195 return ser.addOutput(output_shape, out_dtype)