MLCE-326 'Support Dilation in Conv2D in ONNX and Tensorflow Parsers'

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I4a0f07b1e8f80aff0d29405def1f33bde7944e31
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index f926013..d13a277 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -423,26 +423,27 @@
     "Assert"
 };
 
-inline void CalculateSamePadding(uint32_t inputSize, uint32_t stride,
-                                 uint32_t filterSize, bool samePadding,
-                                 uint32_t* paddingFront, uint32_t* paddingBack) {
-    *paddingFront = 0;
-    *paddingBack = 0;
-
-    if (samePadding) {
-        uint32_t outputSize = (inputSize + stride - 1) / stride;
-        uint32_t temp = (outputSize - 1) * stride + filterSize;
-        if (temp > inputSize) {
-            *paddingFront = (temp - inputSize) / 2;
-            *paddingBack = (temp - inputSize) - *paddingFront;
-        }
-    }
-}
-
-void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t& outPadHead, uint32_t& outPadTail,
+void CalcPadding(uint32_t inputSize,
+                 uint32_t filterSize,
+                 uint32_t stride,
+                 uint32_t dilation,
+                 uint32_t& paddingFront,
+                 uint32_t& paddingBack,
                  bool samePadding)
 {
-    CalculateSamePadding(input, stride, kernel, samePadding, &outPadHead, &outPadTail);
+    paddingFront = 0;
+    paddingBack = 0;
+    if (samePadding)
+    {
+        uint32_t outputSize = (inputSize + stride - 1) / stride;
+        uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
+        uint32_t temp = (outputSize - 1) * stride + dilatedSize;
+        if (temp > inputSize)
+        {
+            paddingFront = (temp - inputSize) / 2;
+            paddingBack = (temp - inputSize) - paddingFront;
+        }
+    }
 }
 
 /// An Abstract base class which represents a single tensorflow operation (node)
@@ -1229,22 +1230,6 @@
     std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
     std::vector<uint32_t> strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides");
 
-    // Read the dilations, if present - only [1,1,1,1] (the default) is supported.
-    std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(nodeDef, "dilations");
-    if (!dilations.empty())
-    {
-        for (auto dilation : dilations)
-        {
-            if (dilation != 1u)
-            {
-                throw ParseException(
-                    fmt::format("ArmNN only supports Convolution layers with dilations [1,1,1,1] for {} {}",
-                                nodeDef.name(),
-                                CHECK_LOCATION().AsString()));
-            }
-        }
-    }
-
     Convolution2dDescriptor desc;
     desc.m_BiasEnabled = false;
 
@@ -1259,6 +1244,13 @@
     desc.m_StrideX = strides[dataLayoutIndexed.GetWidthIndex()];
     desc.m_StrideY = strides[dataLayoutIndexed.GetHeightIndex()];
 
+    std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(nodeDef, "dilations");
+    if (!dilations.empty())
+    {
+        desc.m_DilationX = dilations[dataLayoutIndexed.GetWidthIndex()];
+        desc.m_DilationY = dilations[dataLayoutIndexed.GetHeightIndex()];
+    }
+
     uint32_t inputHeight = inputTensorInfo.GetShape()[dataLayoutIndexed.GetHeightIndex()];
     uint32_t inputWidth  = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()];
 
@@ -1296,22 +1288,24 @@
     if (paddingString == "SAME")
     {
         padding = true;
-
-        outputHeight = static_cast<uint32_t>(ceil(static_cast<float>(inputHeight) /
-                                                  static_cast<float>(desc.m_StrideY)));
-        outputWidth  = static_cast<uint32_t>(ceil(static_cast<float>(inputWidth) /
-                                                  static_cast<float>(desc.m_StrideX)));
     }
     else if (paddingString == "VALID")
     {
         padding = false;
-
-        outputHeight = static_cast<uint32_t>(ceil(static_cast<float>(inputHeight - weightHeight + 1) /
-                                                  static_cast<float>(desc.m_StrideY)));
-        outputWidth  = static_cast<uint32_t>(ceil(static_cast<float>(inputWidth - weightWidth + 1) /
-                                                  static_cast<float>(desc.m_StrideX)));
     }
 
+    CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, padding);
+    CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, padding);
+
+    // Calculate output height and width
+    unsigned int dilatedFilterWidth = weightWidth + (desc.m_DilationX - 1) * (weightWidth - 1);
+    unsigned int readWidth = (inputWidth + desc.m_PadLeft + desc.m_PadRight) - dilatedFilterWidth;
+    outputWidth = 1 + (readWidth / desc.m_StrideX);
+
+    unsigned int dilatedFilterHeight = weightHeight + (desc.m_DilationY - 1) * (weightHeight - 1);
+    unsigned int readHeight = (inputHeight + desc.m_PadTop + desc.m_PadBottom) - dilatedFilterHeight;
+    outputHeight = 1 + (readHeight / desc.m_StrideY);
+
     switch (dataLayout)
     {
     case DataLayout::NHWC:
@@ -1331,9 +1325,6 @@
         break;
     }
 
-    CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, padding);
-    CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, padding);
-
     IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc,
                                                                 weightTensor,
                                                                 EmptyOptional(),
@@ -1382,6 +1373,12 @@
 
     desc.m_StrideX = strides[dataLayoutIndexed.GetWidthIndex()];
     desc.m_StrideY = strides[dataLayoutIndexed.GetHeightIndex()];
+    std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(nodeDef, "dilations");
+    if (!dilations.empty())
+    {
+        desc.m_DilationX = dilations[dataLayoutIndexed.GetWidthIndex()];
+        desc.m_DilationY = dilations[dataLayoutIndexed.GetHeightIndex()];
+    }
 
     uint32_t inputHeight = inputTensorInfo.GetShape()[dataLayoutIndexed.GetHeightIndex()];
     uint32_t inputWidth  = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()];
@@ -1416,22 +1413,24 @@
     if (paddingString == "SAME")
     {
         padding = true;
-
-        outputHeight = static_cast<uint32_t>(ceil(static_cast<float>(inputHeight) /
-                                                  static_cast<float>(desc.m_StrideY)));
-        outputWidth  = static_cast<uint32_t>(ceil(static_cast<float>(inputWidth) /
-                                                  static_cast<float>(desc.m_StrideX)));
     }
     else if (paddingString == "VALID")
     {
         padding = false;
-
-        outputHeight = static_cast<uint32_t>(ceil(static_cast<float>(inputHeight - weightHeight + 1) /
-                                                  static_cast<float>(desc.m_StrideY)));
-        outputWidth  = static_cast<uint32_t>(ceil(static_cast<float>(inputWidth - weightWidth + 1) /
-                                                  static_cast<float>(desc.m_StrideX)));
     }
 
+    CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, padding);
+    CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, padding);
+
+    // Calculate output height and width
+    unsigned int dilatedFilterWidth = weightWidth + (desc.m_DilationX - 1) * (weightWidth - 1);
+    unsigned int readWidth = (inputWidth + desc.m_PadLeft + desc.m_PadRight) - dilatedFilterWidth;
+    outputWidth = 1 + (readWidth / desc.m_StrideX);
+
+    unsigned int dilatedFilterHeight = weightHeight + (desc.m_DilationY - 1) * (weightHeight - 1);
+    unsigned int readHeight = (inputHeight + desc.m_PadTop + desc.m_PadBottom) - dilatedFilterHeight;
+    outputHeight = 1 + (readHeight / desc.m_StrideY);
+
     switch (dataLayout)
     {
         case DataLayout::NHWC:
@@ -1451,9 +1450,6 @@
             break;
     }
 
-    CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, padding);
-    CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, padding);
-
     IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
                                                                          weightTensor,
                                                                          EmptyOptional(),
@@ -3094,9 +3090,9 @@
             break;
     }
 
-    CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX,
+    CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX, 1u,
                 pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding);
-    CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY,
+    CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY, 1u,
                 pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding);
 
 
diff --git a/src/armnnTfParser/test/Convolution2d.cpp b/src/armnnTfParser/test/Convolution2d.cpp
index cf71489..c58615f 100644
--- a/src/armnnTfParser/test/Convolution2d.cpp
+++ b/src/armnnTfParser/test/Convolution2d.cpp
@@ -37,7 +37,22 @@
                                 "        i: " + std::to_string(stride) + " \n");
         }
 
-        std::string dilationString = std::to_string(dilation);
+        std::string dilationString;
+        if (dataLayout == "NHWC")
+        {
+            dilationString.append("        i: 1 \n"
+                                  "        i: " + std::to_string(dilation) + " \n"
+                                  "        i: " + std::to_string(dilation) + " \n"
+                                  "        i: 1 \n");
+        }
+        else // dataLayout == "NCHW"
+        {
+            dilationString.append("        i: 1 \n"
+                                  "        i: 1 \n"
+                                  "        i: " + std::to_string(dilation) + " \n"
+                                  "        i: " + std::to_string(dilation) + " \n");
+        }
+
         m_Prototext = "node { \n"
             "    name: \"graphInput\" \n"
             "    op: \"Placeholder\" \n"
@@ -130,16 +145,10 @@
             m_Prototext.append("  attr { \n"
                                "    key: \"dilations\" \n"
                                "    value { \n"
-                               "      list { \n"
-                               "        i: 1 \n"
-                               "        i: ");
+                               "      list { \n");
             m_Prototext.append(dilationString);
-            m_Prototext.append(" \n"
-                               "        i: ");
-            m_Prototext.append(dilationString);
-            m_Prototext.append(" \n"
-                               "        i: 1 \n"
-                               "      } \n"
+
+            m_Prototext.append("      } \n"
                                "    } \n"
                                "  } \n");
         }
@@ -167,7 +176,6 @@
     }
 };
 
-
 struct Convolution2dNhwcSameFixture : Convolution2dFixture
 {
     Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
@@ -262,118 +270,174 @@
     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
 }
 
-
-BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)
+struct Convolution2dDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
 {
-    const char* prototext = ""
-        "node {\n"
-        "  name: \"graphInput\"\n"
-        "  op: \"Placeholder\"\n"
-        "  attr {\n"
-        "    key: \"dtype\"\n"
-        "    value {\n"
-        "      type: DT_FLOAT\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"shape\"\n"
-        "    value {\n"
-        "      shape {\n"
-        "      }\n"
-        "    }\n"
-        "  }\n"
-        "}\n"
-        "node {\n"
-        "  name: \"Const_1\"\n"
-        "  op: \"Const\"\n"
-        "  attr {\n"
-        "    key: \"dtype\"\n"
-        "    value {\n"
-        "      type: DT_FLOAT\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"value\"\n"
-        "    value {\n"
-        "      tensor {\n"
-        "        dtype: DT_FLOAT\n"
-        "        tensor_shape {\n"
-        "          dim {\n"
-        "            size: 1\n"
-        "          }\n"
-        "          dim {\n"
-        "            size: 3\n"
-        "          }\n"
-        "          dim {\n"
-        "            size: 1\n"
-        "          }\n"
-        "          dim {\n"
-        "            size: 1\n"
-        "          }\n"
-        "        }\n"
-        "        tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
-        "      }\n"
-        "    }\n"
-        "  }\n"
-        "}\n"
-        "node {\n"
-        "  name: \"potato\"\n"
-        "  op: \"Conv2D\"\n"
-        "  input: \"graphInput\"\n"
-        "  input: \"Const_1\"\n"
-        "  attr {\n"
-        "    key: \"T\"\n"
-        "    value {\n"
-        "      type: DT_FLOAT\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"data_format\"\n"
-        "    value {\n"
-        "      s: \"NHWC\"\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"padding\"\n"
-        "    value {\n"
-        "      s: \"SAME\"\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"strides\"\n"
-        "    value {\n"
-        "      list {\n"
-        "        i: 1\n"
-        "        i: 1\n"
-        "        i: 1\n"
-        "        i: 1\n"
-        "      }\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"dilations\"\n"
-        "    value {\n"
-        "      list {\n"
-        "        i: 1\n"
-        "        i: 2\n"
-        "        i: 2\n"
-        "        i: 1\n"
-        "      }\n"
-        "    }\n"
-        "  }\n"
-        "  attr {\n"
-        "    key: \"use_cudnn_on_gpu\"\n"
-        "    value {\n"
-        "      b: false\n"
-        "    }\n"
-        "  }\n"
-        "}\n";
+    explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType)
+        : Convolution2dDilationFixture(dataLayout, paddingType, 1)
+    {}
 
-    std::map<std::string, armnn::TensorShape> inputShapes;
-    armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
-    inputShapes["graphInput"] = tensorShape;
-    armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
-    BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }), armnn::ParseException);
+    explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType,
+                                  int stride, int dilation = 0)
+    {
+        std::string strideString;
+        if (dataLayout == "NHWC")
+        {
+            strideString.append("        i: 1 \n"
+                                "        i: " + std::to_string(stride) + " \n"
+                                "        i: " + std::to_string(stride) + " \n"
+                                "        i: 1 \n");
+        }
+        else // dataLayout == "NCHW"
+        {
+            strideString.append("        i: 1 \n"
+                                "        i: 1 \n"
+                                "        i: " + std::to_string(stride) + " \n"
+                                "        i: " + std::to_string(stride) + " \n");
+        }
+
+        std::string dilationString;
+        if (dataLayout == "NHWC")
+        {
+            dilationString.append("        i: 1 \n"
+                                  "        i: " + std::to_string(dilation) + " \n"
+                                  "        i: " + std::to_string(dilation) + " \n"
+                                  "        i: 1 \n");
+        }
+        else // dataLayout == "NCHW"
+        {
+            dilationString.append("        i: 1 \n"
+                                  "        i: 1 \n"
+                                  "        i: " + std::to_string(dilation) + " \n"
+                                  "        i: " + std::to_string(dilation) + " \n");
+        }
+
+        m_Prototext = "node { \n"
+                      "    name: \"graphInput\" \n"
+                      "    op: \"Placeholder\" \n"
+                      "    attr { \n"
+                      "      key: \"dtype\" \n"
+                      "      value { \n"
+                      "        type: DT_FLOAT \n"
+                      "      } \n"
+                      "    } \n"
+                      "    attr { \n"
+                      "      key: \"shape\" \n"
+                      "      value { \n"
+                      "        shape { \n"
+                      "        } \n"
+                      "      } \n"
+                      "    } \n"
+                      "  } \n"
+                      "  node { \n"
+                      "  name: \"Const_1\" \n"
+                      "  op: \"Const\" \n"
+                      "  attr { \n"
+                      "    key: \"dtype\" \n"
+                      "    value { \n"
+                      "      type: DT_FLOAT \n"
+                      "    } \n"
+                      "  } \n"
+                      "  attr { \n"
+                      "    key: \"value\" \n"
+                      "    value { \n"
+                      "      tensor { \n"
+                      "        dtype: DT_FLOAT \n"
+                      "        tensor_shape { \n"
+                      "          dim { \n"
+                      "            size: 3 \n"
+                      "          } \n"
+                      "          dim { \n"
+                      "            size: 1 \n"
+                      "          } \n"
+                      "          dim { \n"
+                      "            size: 1 \n"
+                      "          } \n"
+                      "          dim { \n"
+                      "            size: 1 \n"
+                      "          } \n"
+                      "        } \n"
+                      "        tensor_content: \"\\001\\000\\000?\\000\\000\\000?\\001\\000\\000?\" \n"
+                      "      } \n"
+                      "    } \n"
+                      "  } \n"
+                      "} \n"
+                      "node { \n"
+                      "  name: \"potato\" \n"
+                      "  op: \"Conv2D\" \n"
+                      "  input: \"graphInput\" \n"
+                      "  input: \"Const_1\" \n"
+                      "  attr { \n"
+                      "    key: \"T\" \n"
+                      "    value { \n"
+                      "      type: DT_FLOAT \n"
+                      "    } \n"
+                      "  } \n"
+                      "  attr { \n"
+                      "    key: \"data_format\" \n"
+                      "    value { \n"
+                      "      s: \"";
+        m_Prototext.append(dataLayout);
+        m_Prototext.append("\"\n"
+                           "    } \n"
+                           "  } \n"
+                           "  attr { \n"
+                           "    key: \"padding\" \n"
+                           "    value { \n"
+                           "      s: \"");
+        m_Prototext.append(paddingType);
+        m_Prototext.append("\"\n"
+                           "    } \n"
+                           "  } \n"
+                           "  attr { \n"
+                           "    key: \"strides\" \n"
+                           "    value { \n"
+                           "      list { \n");
+        m_Prototext.append(strideString);
+
+        m_Prototext.append("      } \n"
+                           "    } \n"
+                           "  } \n");
+
+        if (dilation > 0)
+        {
+            m_Prototext.append("  attr { \n"
+                               "    key: \"dilations\" \n"
+                               "    value { \n"
+                               "      list { \n");
+            m_Prototext.append(dilationString);
+
+            m_Prototext.append("      } \n"
+                               "    } \n"
+                               "  } \n");
+        }
+        m_Prototext.append("  attr { \n"
+                           "    key: \"use_cudnn_on_gpu\" \n"
+                           "    value { \n"
+                           "      b: false \n"
+                           "    } \n"
+                           "  } \n"
+                           "} \n");
+
+        // Input tensor shape: NCHW { batch, channels, height, width } = { 1, 1, 6, 6 }.
+        std::array<unsigned int, 4> dims = { 1u, 1u, 6u, 6u };
+
+        SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
+    }
+};
+
+struct Convolution2dDilation2NchwValidFixture : Convolution2dDilationFixture
+{
+    Convolution2dDilation2NchwValidFixture() : Convolution2dDilationFixture("NCHW", "VALID", 1, 2){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation2NchwValid, Convolution2dDilation2NchwValidFixture)
+{
+    RunTest<4>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+                7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+                1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+                7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+                1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+                7.0, 8.0, 9.0, 10.0, 11.0, 12.0},
+               {1.5f, 3.0f, 4.5f, 6.0f, 7.5f, 9.0f, 10.5f, 12.f, 13.5f, 15.0f, 16.5f, 18.0f});
 }