Github #306 Treat data_format attribute as optional in TfParser::ParseFusedBatchNorm()

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I1c6583e4abb43b864dc636f8cdcd9011c763a6fe
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index d085ed8..51423bf 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -195,6 +195,19 @@
     return attriList;
 }
 
+std::string ReadOptionalNodeStringAttribute(const tensorflow::NodeDef& nodeDef,
+                                            const std::string& name,
+                                            const std::string& defaultValue = "")
+{
+    std::string attribValue = defaultValue;
+    ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS,
+        [&attribValue](const tensorflow::AttrValue& attrValue)
+    {
+        attribValue = attrValue.s();
+    });
+    return attribValue;
+}
+
 bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef,
     const std::string& name,
     bool defaultValue = false)
@@ -1594,8 +1607,7 @@
     ParsedConstTfOperation<float>* varianceNode =
         boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
 
-    const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
-
+    const std::string dataFormat = ReadOptionalNodeStringAttribute(nodeDef, "data_format", "NHWC");
     CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm");
 
     // The descriptor only has the epsilon attribute.
diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp
index 98bdb26..b93a472 100644
--- a/src/armnnTfParser/test/FusedBatchNorm.cpp
+++ b/src/armnnTfParser/test/FusedBatchNorm.cpp
@@ -141,16 +141,22 @@
             "    value { \n"
             "      type: DT_FLOAT \n"
             "    } \n"
-            "  } \n"
-            "  attr { \n"
-            "    key: \"data_format\" \n"
-            "    value { \n"
-            "      s: \"";
-        m_Prototext.append(dataLayout);
-        m_Prototext.append("\" \n"
-                           "    } \n"
-                           "  } \n"
-                           "  attr { \n"
+            "  } \n";
+
+        // NOTE: we only explicitly set data_format when it is not the default NHWC
+        if (dataLayout != "NHWC")
+        {
+            m_Prototext.append("  attr { \n"
+                "    key: \"data_format\" \n"
+                "    value { \n"
+                "      s: \"");
+            m_Prototext.append(dataLayout);
+            m_Prototext.append("\" \n"
+                "    } \n"
+                "  } \n");
+        }
+
+        m_Prototext.append("  attr { \n"
                            "    key: \"epsilon\" \n"
                            "    value { \n"
                            "      f: 0.0010000000475 \n"