Implement Conv3D kernel.

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ic16e918b1a2423ad563684e29ce70d9efdbf9c02
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index a150656..a0a1f04 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -482,6 +482,245 @@
 }
 
+// OpConv3d: reference implementation of the CONV3D operator (Op_CONV3D).
+//   inputs[0]  input  [N, ID, IH, IW, IC]
+//   inputs[1]  weight [OC, KD, KH, KW, IC]
+//   inputs[2]  bias   [OC]
+//   outputs[0] output [N, OD, OH, OW, OC]
+// `attribute` and `qinfo` are set up by INIT_ATTRIBUTE / INIT_QINFO and deleted in the destructor.
 template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
+                                         TosaAttributeBase* attribute_,
+                                         TosaQuantInfoBase* qinfo_,
+                                         uint64_t id_)
+    : GraphNode(sgt_, Op_CONV3D, id_)
+{
+    // Three operands (input, weight, bias) and one output.
+    setRequiredOperands(3, 1);
+    // Applies to input, weight and output; bias (rank 1) is checked in checkTensorAttributes().
+    setRequiredRank(5);
+
+    INIT_ATTRIBUTE(Conv);
+    INIT_QINFO(Conv);
+}
+
+// Releases the attribute and quantization info owned by this node.
+template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::~OpConv3d()
+{
+    if (attribute)
+        delete attribute;
+    if (qinfo)
+        delete qinfo;
+}
+
+// Validates operand count, ranks, operand tensor types and attribute sizes, and caches typed
+// pointers to the operand tensors for eval(). Returns 0 on success, 1 on any validation failure.
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // 'bias' is checked separately: it must be rank 1, not the rank 5 required of the other tensors.
+    if (inputs[2]->getRank() != 1)
+    {
+        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
+        return 1;
+    }
+
+    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+    // dynamic_cast yields nullptr if a tensor is not of the type this instantiation expects;
+    // eval() dereferences all four pointers unconditionally, so reject that case here.
+    if (!input || !weight || !bias || !output)
+    {
+        printNodeValidationError("OpConv3d: operand tensor type mismatch");
+        return 1;
+    }
+
+    // padding = [d0, d1, top, bottom, left, right]; stride and dilation = [d, h, w] (see eval()).
+    if (attribute->padding().size() != 6)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute padding");
+        return 1;
+    }
+
+    if (attribute->stride().size() != 3)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute stride");
+        return 1;
+    }
+
+    if (attribute->dilation().size() != 3)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute dilation");
+        return 1;
+    }
+
+    return 0;
+}
+
+// Evaluates the convolution with explicit nested loops. Each output element is bias[oc] plus the
+// sum over (kd, kh, kw, ic) of input_padded * weight, with zero points subtracted first when qinfo
+// is present. INT48 results are then clamped to [AccQMin, AccQMax].
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::eval()
+{
+    int in_batch    = this->input->getShape()[0];
+    int in_depth    = this->input->getShape()[1];
+    int in_height   = this->input->getShape()[2];
+    int in_width    = this->input->getShape()[3];
+    int in_channels = this->input->getShape()[4];
+
+    int f_out_channels = this->weight->getShape()[0];
+    int f_depth        = this->weight->getShape()[1];
+    int f_height       = this->weight->getShape()[2];
+    int f_width        = this->weight->getShape()[3];
+    int f_in_channels  = this->weight->getShape()[4];
+
+    int b_out_channels = this->bias->getShape()[0];
+
+    int out_batch    = this->output->getShape()[0];
+    int out_depth    = this->output->getShape()[1];
+    int out_height   = this->output->getShape()[2];
+    int out_width    = this->output->getShape()[3];
+    int out_channels = this->output->getShape()[4];
+
+    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
+    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
+             in_channels);
+    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
+             out_channels);
+    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+
+    // padding = [d0, d1, top, bottom, left, right]; each pair is (before, after) for pad() below.
+    int padding_d0     = this->attribute->padding()[0];
+    int padding_d1     = this->attribute->padding()[1];
+    int padding_top    = this->attribute->padding()[2];
+    int padding_bottom = this->attribute->padding()[3];
+    int padding_left   = this->attribute->padding()[4];
+    int padding_right  = this->attribute->padding()[5];
+    int stride_d       = this->attribute->stride()[0];
+    int stride_h       = this->attribute->stride()[1];
+    int stride_w       = this->attribute->stride()[2];
+    int dilation_d     = this->attribute->dilation()[0];
+    int dilation_h     = this->attribute->dilation()[1];
+    int dilation_w     = this->attribute->dilation()[2];
+
+    DEBUG_INFO(
+        OP,
+        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
+        "stride=[%d,%d,%d], dilation=[%d,%d,%d], padding=[%d,%d,%d,%d,%d,%d]",
+        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
+        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
+        dilation_w, padding_d0, padding_d1, padding_top, padding_bottom, padding_left, padding_right);
+
+    // Reject attributes/shapes that would make the loop below index outside the padded input
+    // (Eigen element access is only range-checked by eigen_assert, which NDEBUG compiles out).
+    ERROR_IF(stride_d < 1 || stride_h < 1 || stride_w < 1, "OpConv3d: stride must be >= 1, got [%d,%d,%d]", stride_d,
+             stride_h, stride_w);
+    ERROR_IF(dilation_d < 1 || dilation_h < 1 || dilation_w < 1, "OpConv3d: dilation must be >= 1, got [%d,%d,%d]",
+             dilation_d, dilation_h, dilation_w);
+    ERROR_IF(padding_d0 < 0 || padding_d1 < 0 || padding_top < 0 || padding_bottom < 0 || padding_left < 0 || padding_right < 0,
+             "OpConv3d: padding must be >= 0, got [%d,%d,%d,%d,%d,%d]", padding_d0, padding_d1, padding_top,
+             padding_bottom, padding_left, padding_right);
+
+    int padded_depth  = in_depth + padding_d0 + padding_d1;
+    int padded_height = in_height + padding_top + padding_bottom;
+    int padded_width  = in_width + padding_left + padding_right;
+    ERROR_IF(out_depth > 0 && f_depth > 0 && (out_depth - 1) * stride_d + (f_depth - 1) * dilation_d >= padded_depth,
+             "OpConv3d: output depth %d does not fit padded input depth %d", out_depth, padded_depth);
+    ERROR_IF(out_height > 0 && f_height > 0 && (out_height - 1) * stride_h + (f_height - 1) * dilation_h >= padded_height,
+             "OpConv3d: output height %d does not fit padded input height %d", out_height, padded_height);
+    ERROR_IF(out_width > 0 && f_width > 0 && (out_width - 1) * stride_w + (f_width - 1) * dilation_w >= padded_width,
+             "OpConv3d: output width %d does not fit padded input width %d", out_width, padded_width);
+
+    Eigen::array<std::pair<int32_t, int32_t>, 5> padding;
+    padding[0] = std::make_pair(0, 0);
+    padding[1] = std::make_pair(padding_d0, padding_d1);
+    padding[2] = std::make_pair(padding_top, padding_bottom);
+    padding[3] = std::make_pair(padding_left, padding_right);
+    padding[4] = std::make_pair(0, 0);
+
+    TIn input_val      = this->input->getTensor();
+    TWeight weight_val = this->weight->getTensor();
+    if (this->qinfo)
+    {
+        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
+        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+    }
+
+    ETensor5<InEigenType> input_padded = input_val.pad(padding);
+
+    // 1. initialize the output with the bias, broadcast over batch, depth, height and width
+    Eigen::array<Eigen::Index, 5> reshape_dim;
+    reshape_dim.fill(1);
+    reshape_dim[4] = b_out_channels;
+
+    Eigen::array<Eigen::Index, 5> bcast;
+    bcast[0]                  = out_batch;
+    bcast[1]                  = out_depth;
+    bcast[2]                  = out_height;
+    bcast[3]                  = out_width;
+    bcast[4]                  = 1;
+    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+    // 2. direct convolution, accumulated on top of the bias written in step 1
+    for (int ob = 0; ob < out_batch; ob++)
+    {
+        for (int od = 0; od < out_depth; od++)
+        {
+            for (int oh = 0; oh < out_height; oh++)
+            {
+                for (int ow = 0; ow < out_width; ow++)
+                {
+                    for (int oc = 0; oc < out_channels; oc++)
+                    {
+                        AccEigenType acc = 0;
+                        for (int fd = 0; fd < f_depth; fd++)
+                        {
+                            int d_idx = od * stride_d + fd * dilation_d;
+                            for (int fh = 0; fh < f_height; fh++)
+                            {
+                                int h_idx = oh * stride_h + fh * dilation_h;
+                                for (int fw = 0; fw < f_width; fw++)
+                                {
+                                    int w_idx = ow * stride_w + fw * dilation_w;
+                                    for (int ic = 0; ic < in_channels; ic++)
+                                    {
+                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
+                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
+                                    }
+                                }
+                            }
+                        }
+                        // Add (rather than assign) so the bias from step 1 is preserved.
+                        this->output->getTensor()(ob, od, oh, ow, oc) += acc;
+                    }
+                }
+            }
+        }
+    }
+
+    // Clamp INT48 results to [AccQMin, AccQMax].
+    if (AccDtype == DType_INT48)
+    {
+        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+    }
+
+    return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
                                                            TosaQuantInfoBase* qinfo_,
@@ -1221,6 +1416,11 @@
 DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
 DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
 
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT); // OpConv3d<InDtype, WeightDtype>: FLOAT input, FLOAT weight
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);   // INT8 input, INT4 weight
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);   // INT8 input, INT8 weight
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);  // INT16 input, INT8 weight
+
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);