Implement Conv3D kernel.

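Register CONV3D in the op factory and the subgraph traverser, add the
OpConv3d reference implementation (a direct convolution over NDHWC
input and ODHWI weights for the FLOATxFLOAT, INT8xINT4, INT8xINT8 and
INT16xINT8 type combinations), and extend the test generator with
conv3d shape/argument generation and output shape inference.
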
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ic16e918b1a2423ad563684e29ce70d9efdbf9c02
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 193b2af..3bc55a8 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -64,6 +64,12 @@
             DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8);
             DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
             break;
+        case Op_CONV3D:
+            DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
+            DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4);
+            DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8);
+            DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8);
+            break;
         case Op_DEPTHWISE_CONV2D:
             DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
             DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index a150656..a0a1f04 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -482,6 +482,209 @@
 }
 
 template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
+                                         TosaAttributeBase* attribute_,
+                                         TosaQuantInfoBase* qinfo_,
+                                         uint64_t id_)
+    : GraphNode(sgt_, Op_CONV3D, id_)
+{
+    setRequiredOperands(3, 1);
+    setRequiredRank(5);
+
+    INIT_ATTRIBUTE(Conv);
+    INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::~OpConv3d()
+{
+    if (attribute)
+        delete attribute;
+    if (qinfo)
+        delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // 'bias' is checked separately since it doesn't make sense to widen the required rank range to 1-5 just to cover it
+    if (inputs[2]->getRank() != 1)
+    {
+        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
+        return 1;
+    }
+
+    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+    if (attribute->padding().size() != 6)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute padding");
+        return 1;
+    }
+
+    if (attribute->stride().size() != 3)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute stride");
+        return 1;
+    }
+
+    if (attribute->dilation().size() != 3)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute dilation");
+        return 1;
+    }
+
+    return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::eval()
+{
+    int in_batch    = this->input->getShape()[0];
+    int in_depth    = this->input->getShape()[1];
+    int in_height   = this->input->getShape()[2];
+    int in_width    = this->input->getShape()[3];
+    int in_channels = this->input->getShape()[4];
+
+    int f_out_channels = this->weight->getShape()[0];
+    int f_depth        = this->weight->getShape()[1];
+    int f_height       = this->weight->getShape()[2];
+    int f_width        = this->weight->getShape()[3];
+    int f_in_channels  = this->weight->getShape()[4];
+
+    int b_out_channels = this->bias->getShape()[0];
+
+    int out_batch    = this->output->getShape()[0];
+    int out_depth    = this->output->getShape()[1];
+    int out_height   = this->output->getShape()[2];
+    int out_width    = this->output->getShape()[3];
+    int out_channels = this->output->getShape()[4];
+
+    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
+    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
+             in_channels);
+    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
+             out_channels);
+    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+
+    int padding_d0     = this->attribute->padding()[0];
+    int padding_d1     = this->attribute->padding()[1];
+    int padding_top    = this->attribute->padding()[2];
+    int padding_bottom = this->attribute->padding()[3];
+    int padding_left   = this->attribute->padding()[4];
+    int padding_right  = this->attribute->padding()[5];
+    int stride_d       = this->attribute->stride()[0];
+    int stride_h       = this->attribute->stride()[1];
+    int stride_w       = this->attribute->stride()[2];
+    int dilation_d     = this->attribute->dilation()[0];
+    int dilation_h     = this->attribute->dilation()[1];
+    int dilation_w     = this->attribute->dilation()[2];
+
+    DEBUG_INFO(
+        OP,
+        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
+        "stride=[%d,%d,%d], dilation=[%d,%d,%d], padding=[%d,%d,%d,%d,%d,%d]",
+        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
+        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
+        dilation_w, padding_d0, padding_d1, padding_top, padding_bottom, padding_left, padding_right);
+
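+    // zero-pad only the spatial dims (D, H, W); batch and channel are left unpadded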
+    Eigen::array<std::pair<int32_t, int32_t>, 5> padding;
+    padding[0] = std::make_pair(0, 0);
+    padding[1] = std::make_pair(padding_d0, padding_d1);
+    padding[2] = std::make_pair(padding_top, padding_bottom);
+    padding[3] = std::make_pair(padding_left, padding_right);
+    padding[4] = std::make_pair(0, 0);
+
+    TIn input_val      = this->input->getTensor();
+    TWeight weight_val = this->weight->getTensor();
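+    // for quantized types, subtract the input/weight zero points before accumulating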
+    if (this->qinfo)
+    {
+        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
+        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+    }
+
+    ETensor5<InEigenType> input_padded = input_val.pad(padding);
+
+    // 1. initialize the output with the bias, broadcast across the batch and spatial dims
+    Eigen::array<Eigen::Index, 5> reshape_dim;
+    reshape_dim.fill(1);
+    reshape_dim[4] = b_out_channels;
+
+    Eigen::array<Eigen::Index, 5> bcast;
+    bcast[0]                  = out_batch;
+    bcast[1]                  = out_depth;
+    bcast[2]                  = out_height;
+    bcast[3]                  = out_width;
+    bcast[4]                  = 1;
+    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+    // 2. direct convolution
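+    //    out[n, od, oh, ow, oc] = bias[oc] + sum over (fd, fh, fw, ic) of
+    //        in_padded[n, od*stride_d + fd*dilation_d, oh*stride_h + fh*dilation_h,
+    //                  ow*stride_w + fw*dilation_w, ic] * weight[oc, fd, fh, fw, ic]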
+    AccEigenType acc = 0;
+    int d_idx, h_idx, w_idx;
+
+    for (int ob = 0; ob < out_batch; ob++)
+    {
+        for (int od = 0; od < out_depth; od++)
+        {
+            for (int oh = 0; oh < out_height; oh++)
+            {
+                for (int ow = 0; ow < out_width; ow++)
+                {
+                    for (int oc = 0; oc < out_channels; oc++)
+                    {
+                        // start accumulation from the bias value written in step 1
+                        acc = this->output->getTensor()(ob, od, oh, ow, oc);
+                        for (int fd = 0; fd < f_depth; fd++)
+                        {
+                            d_idx = od * stride_d + fd * dilation_d;
+                            for (int fh = 0; fh < f_height; fh++)
+                            {
+                                h_idx = oh * stride_h + fh * dilation_h;
+                                for (int fw = 0; fw < f_width; fw++)
+                                {
+                                    w_idx = ow * stride_w + fw * dilation_w;
+                                    for (int ic = 0; ic < in_channels; ic++)
+                                    {
+                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
+                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
+                                    }
+                                }
+                            }
+                        }
+                        this->output->getTensor()(ob, od, oh, ow, oc) = acc;
+                    }
+                }
+            }
+        }
+    }
+
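+    // clamp results to the INT48 range when accumulating at 48 bits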
+    if (AccDtype == DType_INT48)
+    {
+        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+    }
+
+    return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
                                                            TosaQuantInfoBase* qinfo_,
@@ -1221,6 +1416,11 @@
 DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
 DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
 
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);
+
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index eea351d..2174d62 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -109,6 +109,39 @@
 };
 
 template <DType InDtype, DType WeightDtype>
+class OpConv3d : public GraphNode
+{
+public:
+    OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpConv3d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
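+    // the accumulator type is derived from the input/weight types (e.g. INT16xINT8 -> INT48)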
+    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+    using InEigenType     = typename GetEigenType<InDtype>::type;
+    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using TIn             = Eigen::Tensor<InEigenType, 5>;
+    using TWeight         = Eigen::Tensor<WeightEigenType, 5>;
+    using TBias           = Eigen::Tensor<AccEigenType, 1>;
+    using TAcc            = Eigen::Tensor<AccEigenType, 5>;
+
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TWeight>* weight;
+    TosaReference::TensorTemplate<TBias>* bias;
+    TosaReference::TensorTemplate<TAcc>* output;
+    tosa::TosaConvAttribute* attribute;
+    tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
 class OpDepthwiseConv2d : public GraphNode
 {
 public:
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index ef7bae6..4dba669 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -116,6 +116,7 @@
         switch (op->GetOp())
         {
             case Op_CONV2D:
+            case Op_CONV3D:
             case Op_DEPTHWISE_CONV2D:
             case Op_TRANSPOSE_CONV2D:
             case Op_FULLY_CONNECTED:
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 44582ac..9555195 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -257,6 +257,35 @@
         return [ifm_shape, filter_shape, bias_shape]
 
     @staticmethod
+    def tgConv3D(testGen, op, rank):
+        pl, const = op["operands"]
+
+        assert rank == 5
+
+        # IFM dimensions are NDHWC
+        ifm_shape = testGen.makeShape(rank)
+
+        # Constrain the batch size if requested
+        if testGen.args.max_batch_size:
+            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+        # Get the filter depth/height/width from the operator parameters
+        filter_dhw = op["filter"]
+
+        # Generate a random OFM channel count
+        ofm_channel = testGen.makeShape(1)[0]
+
+        # The filter dimensions are ODHWI
+        filter_shape = np.asarray(
+            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
+        )
+
+        # The bias is OC
+        bias_shape = np.asarray([ofm_channel])
+
+        return [ifm_shape, filter_shape, bias_shape]
+
+    @staticmethod
     def tgTransposeConv2D(testGen, op, rank):
         pl, const = op["operands"]
 
@@ -463,6 +492,43 @@
         return arg_list
 
     @staticmethod
+    def agConv3D(testGen, opName, shapeList, dtype):
+        arg_list = []
+
+        ifm_shape = shapeList[0]
+        filter_shape = shapeList[1]
+
+        # Must be rank 5
+        assert len(ifm_shape) == 5
+        assert len(filter_shape) == 5
+
+        # Generate a single basic argument combination for now
+        # TODO: increase coverage
+        s = [1, 1, 1]
+        p = [0, 0, 0, 0, 0, 0]
+        d = [1, 1, 1]
+        arg_list.append(
+            (
+                "st{}{}{}_pad{}{}{}{}{}{}_dilat{}{}{}".format(
+                    s[0],
+                    s[1],
+                    s[2],
+                    p[0],
+                    p[1],
+                    p[2],
+                    p[3],
+                    p[4],
+                    p[5],
+                    d[0],
+                    d[1],
+                    d[2],
+                ),
+                [s, p, d],
+            )
+        )
+        return arg_list
+
+    @staticmethod
     def agTransposeConv2D(testGen, opName, shapeList, dtype):
         arg_list = []
 
@@ -1357,6 +1423,21 @@
         )
         return result_tens
 
+    def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
+        assert len(padding) == 6
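+        # padding layout: [d0, d1, top, bottom, left, right]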
+        result_tens = OutputShaper.conv3dOp(
+            self.ser, ifm, filter, strides, padding, dilations
+        )
+
+        attr = ts.TosaSerializerAttribute()
+        attr.ConvAttribute(padding, strides, dilations)
+
+        self.ser.addOperator(
+            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
+        )
+        return result_tens
+
     def build_transpose_conv2d(
         self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
     ):
@@ -1859,7 +1939,9 @@
                 # Filter out the rank?
                 if rankFilter is not None and r not in rankFilter:
                     continue
-                if (
+                if opName.startswith("conv3d"):
+                    assert r == 5, "conv3d test must have input rank == 5"
+                elif (
                     rankFilter is None
                     and shapeFilter[0] is None
                     and r not in default_test_rank_range
@@ -2188,9 +2270,9 @@
     def createDynamicOpLists(self):
 
         # Dynamically create op lists for convolutions with a list of kernel sizes
-        KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
+        KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
 
-        for k in KERNELS:
+        for k in KERNELS_2D:
             testName = "conv2d_{}x{}".format(k[0], k[1])
             self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
             self.TOSA_OP_LIST[testName]["filter"] = k
@@ -2210,6 +2292,13 @@
             self.TOSA_OP_LIST[testName]["filter"] = k
             self.TOSA_OP_LIST[testName]["template"] = False
 
+        KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
+        for k in KERNELS_3D:
+            testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
+            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
+            self.TOSA_OP_LIST[testName]["filter"] = k
+            self.TOSA_OP_LIST[testName]["template"] = False
+
         # Delete any templates after having created any dynamic ops
         # This is a two-pass operation because it's bad practice to delete
         # keys from dictionaries while iterating
@@ -2286,7 +2375,7 @@
 
     TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
 
-    TYPE_CONV2D = [
+    TYPE_CONV = [
         [DType.INT8, DType.INT4, DType.INT32],
         [DType.INT8, DType.INT8, DType.INT32],
         [DType.INT16, DType.INT8, DType.INT48],
@@ -2319,11 +2408,20 @@
             "rank": (4, 4),
             "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
             "qgen": TosaQuantGen.qgConv,
-            "types": TYPE_CONV2D,
+            "types": TYPE_CONV,
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
             "template": True,
         },
-        # Conv3d TBD
+        # Templated operator.  Filled in by createDynamicOpLists
+        "conv3d_TEMPLATE": {
+            "op": Op.CONV3D,
+            "operands": (1, 2),
+            "rank": (5, 5),
+            "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv3D),
+            "qgen": TosaQuantGen.qgConv,
+            "types": TYPE_CONV,
+            "template": True,
+        },
         # Templated operator.  Filled in by createDynamicOpLists
         "depthwise_conv2d_TEMPLATE": {
             "op": Op.DEPTHWISE_CONV2D,
@@ -2336,7 +2434,7 @@
                 TosaArgGen.agConv2D,
             ),
             "qgen": TosaQuantGen.qgConv,
-            "types": TYPE_CONV2D,
+            "types": TYPE_CONV,
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
             "template": True,
         },
@@ -2346,7 +2444,7 @@
             "rank": (2, 2),
             "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
             "qgen": TosaQuantGen.qgConv,
-            "types": TYPE_CONV2D,
+            "types": TYPE_CONV,
         },
         "matmul": {
             "op": Op.MATMUL,
@@ -2375,7 +2473,7 @@
                 TosaArgGen.agTransposeConv2D,
             ),
             "qgen": TosaQuantGen.qgConv,
-            "types": TYPE_CONV2D,
+            "types": TYPE_CONV,
             "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
             "template": True,
         },
@@ -2909,6 +3007,54 @@
         return ser.addOutput(ofm_shape, out_dtype)
 
     @staticmethod
+    def conv3dOp(ser, ifm, filter, strides, padding, dilations):
+
+        # IFM:    NDHWC
+        # Filter: ODHWI
+        # OFM:    NDHWC
+
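+        # per spatial axis:
+        #   out = (in + pad_before + pad_after - ((k - 1) * dilation + 1)) // stride + 1
+        # where (k - 1) * dilation + 1 == k + (k - 1) * (dilation - 1) is the
+        # effective extent of the dilated kernel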
+        d = (
+            ifm.shape[1]
+            - filter.shape[1]
+            - (filter.shape[1] - 1) * (dilations[0] - 1)
+            + padding[0]
+            + padding[1]
+        ) // strides[0] + 1
+
+        h = (
+            ifm.shape[2]
+            - filter.shape[2]
+            - (filter.shape[2] - 1) * (dilations[1] - 1)
+            + padding[2]
+            + padding[3]
+        ) // strides[1] + 1
+
+        w = (
+            ifm.shape[3]
+            - filter.shape[3]
+            - (filter.shape[3] - 1) * (dilations[2] - 1)
+            + padding[4]
+            + padding[5]
+        ) // strides[2] + 1
+
+        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
+
+        if ifm.dtype == DType.INT8:
+            out_dtype = DType.INT32
+        elif ifm.dtype == DType.INT16:
+            out_dtype = DType.INT48
+        elif ifm.dtype == DType.FLOAT:
+            out_dtype = DType.FLOAT
+        else:
+            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
+
+        return ser.addOutput(ofm_shape, out_dtype)
+
+    @staticmethod
     def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
         # IFM:    NHWC
         # Filter: HWCM