Add DIM operator to reference model

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iea11ee5d3d98773e9c5e9b827593c05afb41ce3b
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
index bde678a..1b01a0e 100644
--- a/reference_model/include/dtype.h
+++ b/reference_model/include/dtype.h
@@ -40,6 +40,7 @@
     TOSA_REF_TYPE_UINT16  = 9,
     TOSA_REF_TYPE_FP16    = 10,
     TOSA_REF_TYPE_BF16    = 11,
+    TOSA_REF_TYPE_SHAPE   = 12,
     TOSA_REF_TYPE_FP64    = 99,    // FP64 is special: add new data types above
 };
 
@@ -71,6 +72,8 @@
             return EnumNameDType(DType_FP16);
         case TOSA_REF_TYPE_BF16:
             return EnumNameDType(DType_BF16);
+        case TOSA_REF_TYPE_SHAPE:
+            return EnumNameDType(DType_SHAPE);
         case TOSA_REF_TYPE_FP64:
             return "FP64";
         default:
@@ -82,7 +85,7 @@
 // return corresponding TOSA_REF_TYPE for DType
 inline TOSA_REF_TYPE ConvertDType(const DType dtype)
 {
-    assert(DType_MAX == DType_BF16);    // must update whenever DType_MAX changes
+    assert(DType_MAX == DType_SHAPE);    // must update whenever DType_MAX changes
 
     if (g_func_config.precise_mode)
     {
@@ -122,6 +125,8 @@
             return TOSA_REF_TYPE_FP16;
         case DType_BF16:
             return TOSA_REF_TYPE_BF16;
+        case DType_SHAPE:
+            return TOSA_REF_TYPE_SHAPE;
         default:
             break;
     }
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 6b4ec58..b0a3227 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -57,6 +57,7 @@
         tosa_datatype_int8_t   = 8,
         tosa_datatype_uint16_t = 9,
         tosa_datatype_uint8_t  = 10,
+        tosa_datatype_shape_t  = 11,
     };
 
     struct tosa_tensor_t
@@ -275,6 +276,8 @@
                                const float client_pad_const_fp,
                                tosa_tensor_t client_output);
 
+    tosa_status_t tosa_run_dim(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output);
+
     tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1,
                                    const int32_t client_new_shape_len,
                                    const int32_t client_new_shape[],
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 1ae0683..ae5963d 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -68,6 +68,8 @@
             return tosa::DType::DType_UINT16;
         case tosa_datatype_uint8_t:
             return tosa::DType::DType_UINT8;
+        case tosa_datatype_shape_t:
+            return tosa::DType::DType_SHAPE;
         default:
             return tosa::DType::DType_UNKNOWN;
     }
@@ -1978,6 +1980,34 @@
         return tosa_status_valid;
     }
 
+    tosa_status_t tosa_run_dim(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output)
+    {
+        // Create operator attributes
+        TosaAxisAttribute attr(client_axis);
+
+        // Create tensors
+        tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
+        tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
+
+        // Create operator
+        auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_DIM, tosa::Attribute::Attribute_AxisAttribute, &attr,
+                                                      { input1->GetName() }, { output->GetName() });
+
+        // Create a tosa single-op basic block
+        tosa::TosaSerializationBasicBlock block("dim", "main", { op }, { input1, output }, { input1->GetName() },
+                                                { output->GetName() });
+
+        // Setup model
+        TosaReference::ModelRunnerImpl runner;
+        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+        TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
+
+        // Execute
+        TOSA_RETURN_ON_ERROR(runner.getOutput(output->GetName(), client_output.data, client_output.size));
+
+        return tosa_status_valid;
+    }
+
     tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1,
                                    const int32_t client_new_shape_len,
                                    const int32_t client_new_shape[],
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index bc97c89..2d1fdb0 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -209,6 +209,60 @@
     return GraphNode::eval();
 }
 
+template <int Rank, TOSA_REF_TYPE Dtype>
+OpDim<Rank, Dtype>::OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+    : GraphNode(sgt_, Op_DIM, id_)
+{
+    setRequiredOperands(1, 1);
+
+    INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+OpDim<Rank, Dtype>::~OpDim()
+{
+    if (attribute)
+        delete attribute;
+}
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpDim<Rank, Dtype>::checkTensorAttributes()
+{
+    // Check Tosa Level
+    auto tosa_level = g_func_config.tosa_level;
+    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
+
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]))
+        return 1;
+
+    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
+    {
+        printNodeValidationError("OpDim: axis must between [0, input_rank - 1]");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    ASSERT_MEM(in && out);
+
+    return 0;
+}
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpDim<Rank, Dtype>::eval()
+{
+    int32_t axis    = attribute->axis();
+    int64_t out_val = in->getShape()[axis];
+
+    this->out->getTensor().setConstant(out_val);
+
+    return GraphNode::eval();
+}
+
 template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_RESHAPE, id_)
@@ -780,6 +834,14 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+
 DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
 DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
 DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index 94ce248..024f9a2 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -66,6 +66,27 @@
     TosaPadAttribute* attribute;
 };
 
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpDim : public GraphNode
+{
+public:
+    OpDim(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+    virtual ~OpDim();
+
+    virtual int checkTensorAttributes();
+    virtual int eval();
+
+    using InEigenType  = typename GetEigenType<Dtype>::type;
+    using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+    using TIn          = Eigen::Tensor<InEigenType, Rank>;
+    using TOut         = Eigen::Tensor<OutEigenType, 0>;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TOut>* out;
+    TosaAxisAttribute* attribute;
+};
+
 template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 class OpReshape : public GraphNode
 {
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index a3069dc..d834b74 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -419,6 +419,15 @@
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
             break;
+        case Op_DIM:
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP32);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+            break;
         case Op_RESHAPE:
             DEF_FACTORY_RESHAPE(OpReshape, FP16);
             DEF_FACTORY_RESHAPE(OpReshape, BF16);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 6dd6e76..342d5c2 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -98,6 +98,11 @@
     using type = int64_t;
 };
 template <>
+struct GetEigenType<TOSA_REF_TYPE_SHAPE>
+{
+    using type = int64_t;
+};
+template <>
 struct GetEigenType<TOSA_REF_TYPE_BOOL>
 {
     using type = bool;
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 5675be9..186cb8b 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -478,7 +478,8 @@
                 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
             }
             break;
-            case DType_INT48: {
+            case DType_INT48:
+            case DType_SHAPE: {
                 std::vector<int64_t> i64_data;
                 TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
                 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 4982c99..1aabe5b 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -137,6 +137,7 @@
             nperror = NumpyUtilities::readFromNpyFile(filename, elements, i32databuf);
             break;
         case DType_INT48:
+        case DType_SHAPE:
             i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
             ASSERT_MEM(i64databuf);
 
@@ -220,6 +221,7 @@
             }
             break;
         case TOSA_REF_TYPE_INT48:
+        case TOSA_REF_TYPE_SHAPE:
             if (setTensorValueInt64(elements, i64databuf))
             {
                 free(i64databuf);
@@ -379,6 +381,7 @@
             free(i32databuf);
             break;
         case TOSA_REF_TYPE_INT48:
+        case TOSA_REF_TYPE_SHAPE:
             i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
             ASSERT_MEM(i64databuf);
 
@@ -672,6 +675,7 @@
     switch (getDtype())
     {
         case TOSA_REF_TYPE_INT48:
+        case TOSA_REF_TYPE_SHAPE:
             if (vals.size() != elements)
             {
                 WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -847,6 +851,7 @@
     switch (getDtype())
     {
         case TOSA_REF_TYPE_INT48:
+        case TOSA_REF_TYPE_SHAPE:
             if (vals.size() != elements)
             {
                 WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index f59a5e1..74f57ed 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -811,6 +811,13 @@
                         return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
                 }
                 break;
+            case TOSA_REF_TYPE_SHAPE:
+                // if shape information is not already set, set it here.
+                if (shape_.size() == 0)
+                {
+                    shape_ = { 1 };
+                }
+                return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
             case TOSA_REF_TYPE_BOOL:
                 switch (rank)
                 {
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index f4e2a61..b7bbfc3 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -506,6 +506,13 @@
     param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]
 
 
+class DimOeprator(Operator):
+    """Test selector for the DIM operator."""
+
+    name = "dim"
+    param_names = ["shape", "type", "axis"]
+
+
 class EqualOperator(Operator):
     """Test selector for the EQUAL operator."""
 
diff --git a/verif/conformance/tosa_base_profile_ops_info.json b/verif/conformance/tosa_base_profile_ops_info.json
index 4e3cd03..772602b 100644
--- a/verif/conformance/tosa_base_profile_ops_info.json
+++ b/verif/conformance/tosa_base_profile_ops_info.json
@@ -1317,6 +1317,60 @@
             }
         }
     },
+    "dim": {
+        "group": "data_layout",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "int8",
+                        "--target-dtype",
+                        "int16",
+                        "--target-dtype",
+                        "int32",
+                        "--target-dtype",
+                        "bool",
+                        "--tensor-dim-range",
+                        "1,64",
+                        "--target-rank",
+                        "1",
+                        "--target-rank",
+                        "2",
+                        "--target-rank",
+                        "3"
+                    ]
+                ]
+            },
+            "8k_level": {
+                "no_negative_tests": "true",
+                "selector": "8k_level",
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "int8",
+                        "--tensor-dim-range",
+                        "1,10",
+                        "--target-rank",
+                        "6"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                    "shape",
+                    "type"
+                ]
+            }
+        }
+    },
     "equal": {
         "group": "comparison",
         "profile": [
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 07b6af3..7388835 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -787,6 +787,45 @@
             }
         }
     },
+    "dim": {
+        "group": "data_layout",
+        "profile": [
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "fp32",
+                        "--target-dtype",
+                        "fp16",
+                        "--target-dtype",
+                        "bf16",
+                        "--fp-values-range",
+                        "-2.0,2.0",
+                        "--tensor-dim-range",
+                        "1,65",
+                        "--target-rank",
+                        "1",
+                        "--target-rank",
+                        "2",
+                        "--target-rank",
+                        "3"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                    "shape",
+                    "type"
+                ]
+            }
+        }
+    },
     "equal": {
         "group": "comparison",
         "profile": [
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b5e71ac..8c18e67 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -88,7 +88,9 @@
             return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
         elif dtype == DType.UINT16:
             return np.int32(self.rng.integers(low=0, high=65536, size=shape))
-        elif dtype == DType.INT32:
+        elif (
+            dtype == DType.INT32 or dtype == DType.SHAPE
+        ):  # restricting too large value for SHAPE
             return np.int32(
                 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
             )
@@ -181,7 +183,9 @@
             low, high = (-128, 128)
         elif dtype == DType.INT16:
             low, high = (-32768, 32768)
-        elif dtype == DType.INT32:
+        elif (
+            dtype == DType.INT32 or dtype == DType.SHAPE
+        ):  # restricting too large value for SHAPE
             low, high = (-(1 << 31), (1 << 31))
         elif dtype == DType.INT48:
             low, high = (-(1 << 47), (1 << 47))
@@ -1310,6 +1314,49 @@
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
+    def build_dim(
+        self,
+        op,
+        a,
+        axis,
+        validator_fcns=None,
+        error_name=None,
+        qinfo=None,
+    ):
+        result_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [a.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_list, output_list
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            axis=axis,
+            input_shape=a.shape,
+            input_dtype=a.dtype,
+            output_shape=result_tens.shape,
+            output_dtype=result_tens.dtype,
+            result_tensors=[result_tens],
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        ):
+            return None
+
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(axis)
+
+        self.ser.addOperator(op["op"], input_list, output_list, attr)
+        return result_tens
+
     def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
         result_tens = OutputShaper.reshapeOp(
             self.ser, self.rng, a, newShape, error_name
@@ -3749,6 +3796,25 @@
                 TosaErrorValidator.evWrongRank,
             ),
         },
+        "dim": {
+            "op": Op.DIM,
+            "operands": (1, 0),
+            "build_fcn": (
+                build_dim,
+                TosaTensorGen.tgBasic,
+                TosaTensorValuesGen.tvgDefault,
+                TosaArgGen.agAxis,
+            ),
+            "types": TYPE_FIB,
+            "error_if_validators": (
+                TosaErrorValidator.evAxisLargerRank,
+                TosaErrorValidator.evAxisSmallerZero,
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+                TosaErrorValidator.evWrongRank,
+            ),
+        },
         "reshape": {
             "op": Op.RESHAPE,
             "operands": (1, 0),
@@ -4665,6 +4731,27 @@
         return ser.addOutput(output_shape, outputDType)
 
     @staticmethod
+    def dimOp(ser, rng, a, axis, error_name=None):
+        output_shape = [1]
+
+        if error_name == ErrorIf.WrongOutputType:
+            all_dtypes = [
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+                DType.FP32,
+                DType.FP16,
+                DType.BF16,
+            ]
+            wrong_dtypes = list(set(all_dtypes))
+            outputDType = rng.choice(wrong_dtypes)
+        else:
+            outputDType = DType.SHAPE
+
+        return ser.addOutput(output_shape, outputDType)
+
+    @staticmethod
     def reshapeOp(ser, rng, a, shape, error_name=None):
         output_shape = shape.copy()
 
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 8ff62f1..f9df8d5 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -18,6 +18,7 @@
     DType.UINT16: {"str": "u16", "width": 16},
     DType.INT32: {"str": "i32", "width": 32},
     DType.INT48: {"str": "i48", "width": 48},
+    DType.SHAPE: {"str": "i64", "width": 64},
     DType.FP16: {"str": "f16", "width": 16},
     DType.BF16: {"str": "bf16", "width": 16},
     DType.FP32: {"str": "f32", "width": 32},