Refactor ref_model rank checking and add level check to argmax
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Iad035b31d5e5e83040068e6311501490765bfff7
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 3433192..aafc07f 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -270,19 +270,15 @@
int setRequiredRank(const int min, const int max = -1)
{
- if (max == -1)
- {
- requiredRankMin = requiredRankMax = min;
- }
- else
- {
- requiredRankMin = min;
- requiredRankMax = max;
- }
+ requiredRankMin = min;
+ requiredRankMax = max;
- ASSERT_MSG(requiredRankMin <= requiredRankMax,
- "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
- requiredRankMax);
+ if (requiredRankMin >= 0 && requiredRankMax >= 0)
+ {
+ ASSERT_MSG(requiredRankMin <= requiredRankMax,
+ "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
+ requiredRankMax);
+ }
return 0;
}
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 442cef8..fd19f96 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -27,7 +27,7 @@
: GraphNode(sgt_, Op_CONCAT, id_)
{
setRequiredOperands(-1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -131,7 +131,7 @@
: GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Pad);
}
@@ -221,7 +221,6 @@
: GraphNode(sgt_, Op_RESHAPE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
INIT_ATTRIBUTE(Reshape);
}
@@ -244,11 +243,6 @@
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same types
if (inputs[0]->matchType(*outputs[0]))
{
@@ -321,7 +315,7 @@
: GraphNode(sgt_, Op_REVERSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -392,7 +386,7 @@
: GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Slice);
}
@@ -465,7 +459,7 @@
: GraphNode(sgt_, Op_TILE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Tile);
}
@@ -667,7 +661,7 @@
: GraphNode(sgt_, Op_TRANSPOSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Transpose);
}
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index c5801e7..1e873e7 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -29,7 +29,6 @@
: GraphNode(sgt_, op_, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(0, 6);
a = b = nullptr;
result = nullptr;
@@ -51,11 +50,6 @@
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// A & B must be the same rank and types
if (inputs[0]->matchRankType(*inputs[1]))
{
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 090ce29..16554b5 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -26,7 +26,6 @@
: GraphNode(sgt_, Op_SELECT, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(0, 6);
}
template <int Rank, TOSA_REF_TYPE Dtype>
@@ -43,12 +42,6 @@
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
- validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same types
if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 514cb84..e6e870e 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -27,7 +27,6 @@
: GraphNode(sgt_, op_, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
fcn = [](InEigenType a) -> OutEigenType {
ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
@@ -49,11 +48,6 @@
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same types
if (inputs[0]->matchRankTypeShape(*outputs[0]))
{
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index ca12cfe..575a500 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -113,11 +113,6 @@
int16_t border_y = border[0];
int16_t border_x = border[1];
- // Check Tosa Level
- auto tosa_level = g_func_config.tosa_level;
- LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE");
- LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE");
-
ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384,
"OpResize: exceeds maximum dimension");
ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch");
@@ -137,6 +132,11 @@
ERROR_IF((border_x < -16 * scale_x_n || border_x >= scale_x_n),
"OpResize: invalid attribute border width dimension");
+ // Check Tosa Level
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE");
+ LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE");
+
int32_t res_height = 0;
int32_t res_width = 0;
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index bf8ba57..fd48472 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -25,7 +25,7 @@
: GraphNode(sgt_, op_, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 4);
+ setRequiredRank(1, 4);
INIT_ATTRIBUTE(Axis);
}
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index f8fd323..a60819d 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -327,7 +327,7 @@
: GraphNode(sgt_, Op_ARGMAX, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 4);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -405,6 +405,10 @@
template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
+ // Check Tosa Level
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
+
Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
@@ -419,7 +423,7 @@
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Pool);
}
@@ -645,7 +649,7 @@
: GraphNode(sgt_, Op_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Conv);
}
@@ -839,7 +843,7 @@
: GraphNode(sgt_, Op_CONV3D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(5);
+ setRequiredRank(5, 5);
INIT_ATTRIBUTE(Conv);
}
@@ -1042,7 +1046,7 @@
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Conv);
}
@@ -1227,7 +1231,7 @@
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(2);
+ setRequiredRank(2, 2);
INIT_ATTRIBUTE(FullyConnected);
}
@@ -1322,7 +1326,7 @@
: GraphNode(sgt_, Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
INIT_ATTRIBUTE(MatMul);
}
@@ -1460,7 +1464,7 @@
: GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Pool);
}
@@ -1601,7 +1605,7 @@
: GraphNode(sgt_, Op_FFT2D, id_)
{
setRequiredOperands(2, 2);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
INIT_ATTRIBUTE(FFT);
}
@@ -1724,7 +1728,7 @@
: GraphNode(sgt_, Op_RFFT2D, id_)
{
setRequiredOperands(1, 2);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
}
template <TOSA_REF_TYPE Dtype>
@@ -1830,7 +1834,7 @@
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(TransposeConv);
}
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 68ffb1f..fce8e7c 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -31,7 +31,6 @@
: GraphNode(sgt_, Op_RESCALE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
INIT_ATTRIBUTE(Rescale);
}
@@ -52,11 +51,6 @@
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same rank and size
if (inputs[0]->matchRankSize(*outputs[0]))
{
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 08ee8bf..b68a9b6 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -197,9 +197,9 @@
}
// Unary check to make sure rank matches
- const int checkRequiredRank(const int exactRank) const
+ const int checkRequiredRank(const int minRank) const
{
- return (shape.size() == (size_t)exactRank) ? 0 : 1;
+ return (shape.size() >= (size_t)minRank) ? 0 : 1;
}
const int checkRequiredRank(const int minRank, const int maxRank) const
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 65bdeb7..c8c22c2 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2614,7 +2614,7 @@
"argmax": {
"op": Op.ARGMAX,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_argmax,
TosaTensorGen.tgBasic,
diff --git a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py
index df5c0db..7b129da 100644
--- a/verif/runner/tosa_refmodel_sut_run.py
+++ b/verif/runner/tosa_refmodel_sut_run.py
@@ -34,6 +34,7 @@
# Call Reference model with description file to provide all file details
cmd = [
args.ref_model_path,
+ "--tosa_level={}".format(args.tosa_level),
"--operator_fbs={}".format(args.operator_fbs),
"--test_desc={}".format(self.descFile),
]
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index 1f9cd3e..79e6720 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -37,6 +37,7 @@
OUTPUT_CONST_GLOB = "const-*.npy"
TEST_DESC_FILENAME = "desc.json"
+TOSA_LEVEL = "EIGHTK"
# Conversion from refmodel type into the type abbreviation used in the test output
REF_MODEL_TYPE_TO_OUT = {
@@ -182,6 +183,8 @@
str(desc_file),
"--ofm_file",
OUTPUT_OFM_FILE,
+ "--tosa_level",
+ TOSA_LEVEL,
]
try:
run_sh_command(refmodel_cmd, verbose=True, capture_output=True)