IVGCVSW-2267 Remove the input swizzling from ParseFusedBatchNorm

 * Removed the input swizzling when the data layout is NHWC; the data layout
   from the data_format attribute is now set on the descriptor instead
 * Split the unit test into NHWC and NCHW cases

Change-Id: I6b9fef70bc4ba5e01d14cbfaea3c842a289b0a0e
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b40b054..73bdb65 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1421,9 +1421,15 @@
     ParsedConstTfOperation<float>* varianceNode =
         boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
 
+    const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+
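+    // Check that the data format is one of the supported layouts (NHWC or NCHW).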
+    CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm");
+
-    // The descriptor only has the epsilon attribute.
+    // The descriptor takes the epsilon attribute and the data layout.
     BatchNormalizationDescriptor desc;
     desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon");
+    desc.m_DataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW;
 
     // Data for the parsed tensor args (scale, offset, mean, variance) must be stored
     // locally until the layer is added.
@@ -1448,19 +1454,9 @@
 
     IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
 
-    const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
-
-    if (dataFormat == "NHWC")
-    {
-        const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
-        layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
-        layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
-    }
-    else
-    {
-        layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
-        inputSlot.Connect(layer->GetInputSlot(0));
-    }
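+    // Connect the input directly; the layer handles the data layout itself, so no swizzling is needed.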
+    layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
+    inputSlot.Connect(layer->GetInputSlot(0));
 
     return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
 }
diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp
index bb9e3ed..98bdb26 100644
--- a/src/armnnTfParser/test/FusedBatchNorm.cpp
+++ b/src/armnnTfParser/test/FusedBatchNorm.cpp
@@ -7,11 +7,13 @@
 #include "armnnTfParser/ITfParser.hpp"
 #include "ParserPrototxtFixture.hpp"
 
+#include <array>
+
 BOOST_AUTO_TEST_SUITE(TensorflowParser)
 
 struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
 {
-    FusedBatchNormFixture()
+    explicit FusedBatchNormFixture(const std::string& dataLayout)
     {
         m_Prototext = "node { \n"
             "  name: \"graphInput\" \n"
@@ -143,33 +145,62 @@
             "  attr { \n"
             "    key: \"data_format\" \n"
             "    value { \n"
-            "      s: \"NHWC\" \n"
-            "    } \n"
-            "  } \n"
-            "  attr { \n"
-            "    key: \"epsilon\" \n"
-            "    value { \n"
-            "      f: 0.0010000000475 \n"
-            "    } \n"
-            "  } \n"
-            "  attr { \n"
-            "    key: \"is_training\" \n"
-            "    value { \n"
-            "      b: false \n"
-            "    } \n"
-            "  } \n"
-            "} \n";
+            "      s: \"";
+        m_Prototext.append(dataLayout);
+        m_Prototext.append("\" \n"
+                           "    } \n"
+                           "  } \n"
+                           "  attr { \n"
+                           "    key: \"epsilon\" \n"
+                           "    value { \n"
+                           "      f: 0.0010000000475 \n"
+                           "    } \n"
+                           "  } \n"
+                           "  attr { \n"
+                           "    key: \"is_training\" \n"
+                           "    value { \n"
+                           "      b: false \n"
+                           "    } \n"
+                           "  } \n"
+                           "} \n");
 
-        SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output");
+        // Set the input shape according to the data layout
+        std::array<unsigned int, 4> dims;
+        if (dataLayout == "NHWC")
+        {
+            dims = { 1u, 3u, 3u, 1u };
+        }
+        else // dataLayout == "NCHW"
+        {
+            dims = { 1u, 1u, 3u, 3u };
+        }
+
+        SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
     }
 };
 
-BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture)
+struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
 {
-    RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9},             // Input data.
-               {-2.8277204f, -2.12079024f, -1.4138602f,
-                -0.7069301f, 0.0f, 0.7069301f,
-                1.4138602f, 2.12079024f, 2.8277204f});  // Expected output data.
+    FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
+{
+    RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },               // Input data.
+               { -2.8277204f, -2.12079024f, -1.4138602f,
+                 -0.7069301f,  0.0f,         0.7069301f,
+                  1.4138602f,  2.12079024f,  2.8277204f }); // Expected output data.
+}
+
+struct FusedBatchNormNchwFixture : FusedBatchNormFixture
+{
+    FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
+{
+    RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },               // Input data.
+               { -2.8277204f, -2.12079024f, -1.4138602f,
+                 -0.7069301f,  0.0f,         0.7069301f,
+                  1.4138602f,  2.12079024f,  2.8277204f }); // Expected output data.
 }
 
 BOOST_AUTO_TEST_SUITE_END()