Rename FLOAT type to FP32
Update tensor operation naming to state the input type as TxT in
all cases. Affects CONV2D, CONV3D, DEPTHWISE_CONV2D,
FULLY_CONNECTED, TRANSPOSE_CONV2D.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index e0c6cf0..791fbf7 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FLOAT):
+ if dtypeList[0] in (DType.FP16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1106,10 +1106,10 @@
@staticmethod
def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
- if isinstance(dtypes, list) or isinstance(dtypes, tuple):
- input_dtype = dtypes[0]
- else:
- input_dtype = dtypes
+ assert isinstance(dtypes, list) or isinstance(
+ dtypes, tuple
+ ), f"{dtypes} unexpected"
+ input_dtype = dtypes[0]
if error_name == ErrorIf.WrongOutputType:
accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
@@ -1129,9 +1129,9 @@
elif dtype == DType.INT16:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -1245,7 +1245,7 @@
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.FLOAT):
+ elif dtype in (DType.FP16, DType.FP32):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -1303,9 +1303,9 @@
elif dtype == DType.INT8 or dtype == DType.INT16:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -1408,20 +1408,20 @@
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
- elif inDtype == DType.FLOAT:
+ elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
else:
raise Exception("Unexpected input dtype: {}".format(inDtype))
@@ -1826,8 +1826,8 @@
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
- elif dtype == DType.FLOAT:
- outputDTypeList = [DType.FLOAT]
+ elif dtype == DType.FP32:
+ outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors