Refactor ref_model rank checking and add level check to argmax

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Iad035b31d5e5e83040068e6311501490765bfff7
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 3433192..aafc07f 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -270,19 +270,15 @@
 
     int setRequiredRank(const int min, const int max = -1)
     {
-        if (max == -1)
-        {
-            requiredRankMin = requiredRankMax = min;
-        }
-        else
-        {
-            requiredRankMin = min;
-            requiredRankMax = max;
-        }
+        requiredRankMin = min;
+        requiredRankMax = max;
 
-        ASSERT_MSG(requiredRankMin <= requiredRankMax,
-                   "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
-                   requiredRankMax);
+        if (requiredRankMin >= 0 && requiredRankMax >= 0)
+        {
+            ASSERT_MSG(requiredRankMin <= requiredRankMax,
+                    "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
+                    requiredRankMax);
+        }
 
         return 0;
     }
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 442cef8..fd19f96 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -27,7 +27,7 @@
     : GraphNode(sgt_, Op_CONCAT, id_)
 {
     setRequiredOperands(-1, 1);
-    setRequiredRank(1, 6);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Axis);
 }
@@ -131,7 +131,7 @@
     : GraphNode(sgt_, Op_PAD, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(1, 6);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Pad);
 }
@@ -221,7 +221,6 @@
     : GraphNode(sgt_, Op_RESHAPE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
 
     INIT_ATTRIBUTE(Reshape);
 }
@@ -244,11 +243,6 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
-    {
-        return 1;
-    }
-
     // output and input must be the same types
     if (inputs[0]->matchType(*outputs[0]))
     {
@@ -321,7 +315,7 @@
     : GraphNode(sgt_, Op_REVERSE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(1, 6);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Axis);
 }
@@ -392,7 +386,7 @@
     : GraphNode(sgt_, Op_SLICE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(1, 6);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Slice);
 }
@@ -465,7 +459,7 @@
     : GraphNode(sgt_, Op_TILE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(1, 6);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Tile);
 }
@@ -667,7 +661,7 @@
     : GraphNode(sgt_, Op_TRANSPOSE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(1, 6);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Transpose);
 }
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index c5801e7..1e873e7 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -29,7 +29,6 @@
     : GraphNode(sgt_, op_, id_)
 {
     setRequiredOperands(2, 1);
-    setRequiredRank(0, 6);
 
     a = b  = nullptr;
     result = nullptr;
@@ -51,11 +50,6 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
-    {
-        return 1;
-    }
-
     // A & B must be the same rank and types
     if (inputs[0]->matchRankType(*inputs[1]))
     {
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 090ce29..16554b5 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -26,7 +26,6 @@
     : GraphNode(sgt_, Op_SELECT, id_)
 {
     setRequiredOperands(3, 1);
-    setRequiredRank(0, 6);
 }
 
 template <int Rank, TOSA_REF_TYPE Dtype>
@@ -43,12 +42,6 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
-        validateRequiredRank(outputs[0]))
-    {
-        return 1;
-    }
-
     // output and input must be the same types
     if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
         inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 514cb84..e6e870e 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -27,7 +27,6 @@
     : GraphNode(sgt_, op_, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
 
     fcn = [](InEigenType a) -> OutEigenType {
         ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
@@ -49,11 +48,6 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
-    {
-        return 1;
-    }
-
     // output and input must be the same types
     if (inputs[0]->matchRankTypeShape(*outputs[0]))
     {
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index ca12cfe..575a500 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -113,11 +113,6 @@
     int16_t border_y = border[0];
     int16_t border_x = border[1];
 
-    // Check Tosa Level
-    auto tosa_level = g_func_config.tosa_level;
-    LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE");
-    LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE");
-
     ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384,
              "OpResize: exceeds maximum dimension");
     ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch");
@@ -137,6 +132,11 @@
     ERROR_IF((border_x < -16 * scale_x_n || border_x >= scale_x_n),
              "OpResize: invalid attribute border width dimension");
 
+    // Check Tosa Level
+    auto tosa_level = g_func_config.tosa_level;
+    LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE");
+    LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE");
+
     int32_t res_height = 0;
     int32_t res_width = 0;
 
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index bf8ba57..fd48472 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -25,7 +25,7 @@
     : GraphNode(sgt_, op_, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 4);
+    setRequiredRank(1, 4);
 
     INIT_ATTRIBUTE(Axis);
 }
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index f8fd323..a60819d 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -327,7 +327,7 @@
     : GraphNode(sgt_, Op_ARGMAX, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(1, 4);
+    setRequiredRank(1);
 
     INIT_ATTRIBUTE(Axis);
 }
@@ -405,6 +405,10 @@
 template <int Rank, TOSA_REF_TYPE Dtype>
 int OpArgMax<Rank, Dtype>::eval()
 {
+    // Check Tosa Level
+    auto tosa_level = g_func_config.tosa_level;
+    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
+
     Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
 
     this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
@@ -419,7 +423,7 @@
     : GraphNode(sgt_, Op_AVG_POOL2D, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(4);
+    setRequiredRank(4, 4);
 
     INIT_ATTRIBUTE(Pool);
 }
@@ -645,7 +649,7 @@
     : GraphNode(sgt_, Op_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
-    setRequiredRank(4);
+    setRequiredRank(4, 4);
 
     INIT_ATTRIBUTE(Conv);
 }
@@ -839,7 +843,7 @@
     : GraphNode(sgt_, Op_CONV3D, id_)
 {
     setRequiredOperands(3, 1);
-    setRequiredRank(5);
+    setRequiredRank(5, 5);
 
     INIT_ATTRIBUTE(Conv);
 }
@@ -1042,7 +1046,7 @@
     : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
-    setRequiredRank(4);
+    setRequiredRank(4, 4);
 
     INIT_ATTRIBUTE(Conv);
 }
@@ -1227,7 +1231,7 @@
     : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
 {
     setRequiredOperands(3, 1);
-    setRequiredRank(2);
+    setRequiredRank(2, 2);
 
     INIT_ATTRIBUTE(FullyConnected);
 }
@@ -1322,7 +1326,7 @@
     : GraphNode(sgt_, Op_MATMUL, id_)
 {
     setRequiredOperands(2, 1);
-    setRequiredRank(3);
+    setRequiredRank(3, 3);
 
     INIT_ATTRIBUTE(MatMul);
 }
@@ -1460,7 +1464,7 @@
     : GraphNode(sgt_, Op_MAX_POOL2D, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(4);
+    setRequiredRank(4, 4);
 
     INIT_ATTRIBUTE(Pool);
 }
@@ -1601,7 +1605,7 @@
     : GraphNode(sgt_, Op_FFT2D, id_)
 {
     setRequiredOperands(2, 2);
-    setRequiredRank(3);
+    setRequiredRank(3, 3);
 
     INIT_ATTRIBUTE(FFT);
 }
@@ -1724,7 +1728,7 @@
     : GraphNode(sgt_, Op_RFFT2D, id_)
 {
     setRequiredOperands(1, 2);
-    setRequiredRank(3);
+    setRequiredRank(3, 3);
 }
 
 template <TOSA_REF_TYPE Dtype>
@@ -1830,7 +1834,7 @@
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
-    setRequiredRank(4);
+    setRequiredRank(4, 4);
 
     INIT_ATTRIBUTE(TransposeConv);
 }
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 68ffb1f..fce8e7c 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -31,7 +31,6 @@
     : GraphNode(sgt_, Op_RESCALE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
     INIT_ATTRIBUTE(Rescale);
 }
 
@@ -52,11 +51,6 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
-    {
-        return 1;
-    }
-
     // output and input must be the same rank and size
     if (inputs[0]->matchRankSize(*outputs[0]))
     {
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 08ee8bf..b68a9b6 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -197,9 +197,9 @@
     }
 
     // Unary check to make sure rank matches
-    const int checkRequiredRank(const int exactRank) const
+    const int checkRequiredRank(const int minRank) const
     {
-        return (shape.size() == (size_t)exactRank) ? 0 : 1;
+        return (shape.size() >= (size_t)minRank) ? 0 : 1;
     }
 
     const int checkRequiredRank(const int minRank, const int maxRank) const
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 65bdeb7..c8c22c2 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2614,7 +2614,7 @@
         "argmax": {
             "op": Op.ARGMAX,
             "operands": (1, 0),
-            "rank": (1, 4),
+            "rank": (1, 6),
             "build_fcn": (
                 build_argmax,
                 TosaTensorGen.tgBasic,
diff --git a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py
index df5c0db..7b129da 100644
--- a/verif/runner/tosa_refmodel_sut_run.py
+++ b/verif/runner/tosa_refmodel_sut_run.py
@@ -34,6 +34,7 @@
         # Call Reference model with description file to provide all file details
         cmd = [
             args.ref_model_path,
+            "--tosa_level={}".format(args.tosa_level),
             "--operator_fbs={}".format(args.operator_fbs),
             "--test_desc={}".format(self.descFile),
         ]
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index 1f9cd3e..79e6720 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -37,6 +37,7 @@
 OUTPUT_CONST_GLOB = "const-*.npy"
 
 TEST_DESC_FILENAME = "desc.json"
+TOSA_LEVEL = "EIGHTK"
 
 # Conversion from refmodel type into the type abbreviation used in the test output
 REF_MODEL_TYPE_TO_OUT = {
@@ -182,6 +183,8 @@
             str(desc_file),
             "--ofm_file",
             OUTPUT_OFM_FILE,
+            "--tosa_level",
+            TOSA_LEVEL,
         ]
         try:
             run_sh_command(refmodel_cmd, verbose=True, capture_output=True)