IVGCVSW-4496 Add Flatten support to ONNX parser

 * Added ParseFlatten method
 * Added Read int64 attribute method
 - Modified ComputeReshapeInfo method
 - Modified ParseReshape
 * Added unit tests
 - Reorganised OnnxParser.cpp/.hpp

Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com>
Change-Id: I8a9553438dd1e8c702d821b093587e0074c027d5
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9e31a03..20e8717 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -820,20 +820,21 @@
 
     if(BUILD_ONNX_PARSER AND ARMNNREF)
         list(APPEND unittest_sources
-            src/armnnOnnxParser/test/Constructor.cpp
-            src/armnnOnnxParser/test/CreateNetwork.cpp
-            src/armnnOnnxParser/test/ProtoxtFixture.cpp
+            src/armnnOnnxParser/test/Addition.cpp
+            src/armnnOnnxParser/test/BatchNorm.cpp
             src/armnnOnnxParser/test/Clip.cpp
             src/armnnOnnxParser/test/Const.cpp
-            src/armnnOnnxParser/test/Pooling.cpp
-            src/armnnOnnxParser/test/Reshape.cpp
-            src/armnnOnnxParser/test/Relu.cpp
+            src/armnnOnnxParser/test/Constructor.cpp
             src/armnnOnnxParser/test/Conv2D.cpp
-            src/armnnOnnxParser/test/Addition.cpp
+            src/armnnOnnxParser/test/CreateNetwork.cpp
+            src/armnnOnnxParser/test/DepthConv.cpp
+            src/armnnOnnxParser/test/Flatten.cpp
             src/armnnOnnxParser/test/FullyConnected.cpp
             src/armnnOnnxParser/test/GetInputsOutputs.cpp
-            src/armnnOnnxParser/test/BatchNorm.cpp
-            src/armnnOnnxParser/test/DepthConv.cpp
+            src/armnnOnnxParser/test/Pooling.cpp
+            src/armnnOnnxParser/test/ProtoxtFixture.cpp
+            src/armnnOnnxParser/test/Relu.cpp
+            src/armnnOnnxParser/test/Reshape.cpp
             )
     endif()
 
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 455bd87..a07a899 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -119,6 +119,19 @@
     }
 }
 
+int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
+                                       const std::string& name,
+                                       const int64_t defaultValue = 0)
+{
+    int64_t attribValue = defaultValue;
+    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
+                                  [&attribValue](const onnx::AttributeProto& attrValue)
+                                      {
+                                          attribValue = attrValue.i();
+                                      });
+    return attribValue;
+}
+
 std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                            const std::string& name)
 {
@@ -297,14 +310,14 @@
     }
 }
 
-TensorInfo ComputeReshapeInfo(const onnx::TensorProto& targetShapeTensor,
+TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
                               const TensorShape& inShape,
                               const std::string& outName)
 {
     std::vector<int> targetDims;
-    for(int i = 0; i < targetShapeTensor.int64_data_size(); ++i)
+    for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
     {
-        int val = CHECKED_INT32(targetShapeTensor.int64_data(i));
+        int val = CHECKED_INT32(targetShapeTensor[i]);
         if(val == 0)
         {
             targetDims.push_back(static_cast<int>(inShape[static_cast<uint>(i)]));
@@ -362,6 +375,7 @@
     { "LeakyRelu",             &OnnxParser::ParseLeakyRelu },
     { "Conv",                  &OnnxParser::ParseConv },
     { "Add",                   &OnnxParser::ParseAdd },
+    { "Flatten",               &OnnxParser::ParseFlatten},
 };
 
 template<typename TypePair, typename Location>
@@ -803,6 +817,66 @@
     m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
 }
 
+void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
+{
+    ARMNN_ASSERT(node.op_type() == "Conv");
+
+    DepthwiseConvolution2dDescriptor desc;
+    desc.m_PadLeft      = convDesc.m_PadLeft;
+    desc.m_PadRight     = convDesc.m_PadRight;
+    desc.m_PadTop       = convDesc.m_PadTop;
+    desc.m_PadBottom    = convDesc.m_PadBottom;
+    desc.m_StrideX      = convDesc.m_StrideX;
+    desc.m_StrideY      = convDesc.m_StrideY;
+    desc.m_BiasEnabled  = convDesc.m_BiasEnabled;
+
+    armnn::IConnectableLayer* layer;
+    auto weightTensor = CreateConstTensor(node.input(1));
+    TensorShape& weightShape = weightTensor.first.GetShape();
+    weightShape[1] = weightShape[0];
+    weightShape[0] = 1;
+    m_TensorsInfo[node.input(1)].m_info->SetShape(weightShape);
+
+    if (node.input_size() == 3)
+    {
+        if(!m_TensorsInfo[node.input(2)].isConstant())
+        {
+            throw ParseException(boost::str(
+                boost::format("Bias '%1%' should be constant in Conv layer '%2%' %3%")
+                % node.input(2)
+                % node.name()
+                % CHECK_LOCATION().AsString()));
+        }
+        desc.m_BiasEnabled = true;
+        auto biasTensor = CreateConstTensor(node.input(2));
+        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
+                                                          weightTensor.first,
+                                                          Optional<ConstTensor>(biasTensor.first),
+                                                          node.name().c_str());
+    }
+    else
+    {
+        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
+                                                          weightTensor.first,
+                                                          EmptyOptional(),
+                                                          node.name().c_str());
+    }
+    ARMNN_ASSERT(layer != nullptr);
+
+    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
+                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
+                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });
+
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    // register the input connection slots for the layer, connections are made after all layers have been created
+    // only the tensors for the inputs are relevant, exclude the const tensors
+    RegisterInputSlots(layer, {node.input(0)});
+
+    // register the output connection slots for the layer, connections are made after all layers have been created
+    RegisterOutputSlots(layer, {node.output(0)});
+}
+
 void OnnxParser::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
 {
 
@@ -881,84 +955,6 @@
     }
 }
 
-void OnnxParser::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
-{
-    auto armnnTensor = CreateConstTensor(tensorName);
-
-    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
-    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
-    RegisterOutputSlots(layer, {tensorName});
-}
-
-void OnnxParser::ParseConstant(const onnx::NodeProto& node)
-{
-    CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
-    if (!node.attribute(0).has_t())
-    {
-        throw ParseException(boost::str(
-              boost::format("Value not found for Constant node '%1%' %2%")
-              % node.name()
-              % CHECK_LOCATION().AsString()));
-    }
-    const onnx::TensorProto& onnxTensor = node.attribute(0).t();
-
-    //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
-    CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(),
-                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
-
-    //Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
-    m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
-    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
-    m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
-
-    CreateConstantLayer(node.output(0), node.name());
-}
-
-void OnnxParser::ParseMaxPool(const onnx::NodeProto& node)
-{
-    Pooling2dDescriptor desc;
-    desc.m_PoolType = PoolingAlgorithm::Max;
-    desc.m_PaddingMethod = PaddingMethod::Exclude;
-    AddPoolingLayer(node, desc);
-}
-
-void OnnxParser::ParseGlobalAveragePool(const onnx::NodeProto& node)
-{
-    Pooling2dDescriptor desc = Pooling2dDescriptor();
-    desc.m_PoolType = PoolingAlgorithm::Average;
-
-    //kernel size is the same as input
-    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-    desc.m_PoolWidth  = inputShape[3];
-    desc.m_PoolHeight = inputShape[2];
-
-    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
-    ARMNN_ASSERT(layer != nullptr);
-
-    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
-    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
-
-    // register the input connection slots for the layer, connections are made after all layers have been created
-    // only the tensors for the inputs are relevant, exclude the const tensors
-    RegisterInputSlots(layer, {node.input(0)});
-
-    // register the output connection slots for the layer, connections are made after all layers have been created
-    RegisterOutputSlots(layer, {node.output(0)});
-}
-
-void OnnxParser::ParseAveragePool(const onnx::NodeProto& node)
-{
-    Pooling2dDescriptor desc;
-    desc.m_PoolType = PoolingAlgorithm::Average;
-
-    uint32_t count_include_pad = 0;
-    count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
-    if(count_include_pad) {
-        desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
-    }
-    AddPoolingLayer(node, desc);
-}
-
 void OnnxParser::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)
 {
 
@@ -1006,7 +1002,7 @@
             {
                 throw ParseException(boost::str(
                     boost::format("Invalid auto_pad attribute for node %1%. "
-                    "Only SAME_UPPER, SAME_LOWER or VALID supported and found %2% %3%")
+                                  "Only SAME_UPPER, SAME_LOWER or VALID supported and found %2% %3%")
                     % node.name()
                     % paddingString
                     % CHECK_LOCATION().AsString()));
@@ -1040,6 +1036,38 @@
     RegisterOutputSlots(layer, {node.output(0)});
 }
 
+std::pair<std::string, std::string> OnnxParser::AddPrepareBroadcast(const std::string& input0,
+                                                                    const std::string& input1)
+{
+    std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);
+
+    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
+    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();
+
+    if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
+    {
+        auto outputName = boost::str(boost::format("reshape_output_%1%") % input1);
+        PrependForBroadcast(outputName, input1, input0);
+        inputs.second = outputName;
+    }
+    else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
+    {
+        auto outputName = boost::str(boost::format("reshape_output_%1%") % input0);
+        PrependForBroadcast(outputName, input0, input1);
+        inputs.first = outputName;
+    }
+    return inputs;
+}
+
+void OnnxParser::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
+{
+    auto armnnTensor = CreateConstTensor(tensorName);
+
+    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
+    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
+    RegisterOutputSlots(layer, {tensorName});
+}
+
 void OnnxParser::CreateReshapeLayer(const std::string& inputName,
                                     const std::string& outputName,
                                     const std::string& layerName)
@@ -1060,51 +1088,6 @@
     RegisterOutputSlots(layer, {outputName});
 }
 
-void OnnxParser::ParseReshape(const onnx::NodeProto& node)
-{
-    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
-    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
-
-    CHECK_VALID_DATATYPE(node.name(), node.input(0),
-                         m_TensorsInfo[node.input(0)].m_dtype,
-                         onnx::TensorProto::FLOAT); //input
-    CHECK_VALID_DATATYPE(node.name(), node.input(1),
-                         m_TensorsInfo[node.input(1)].m_dtype,
-                         onnx::TensorProto::INT64); //shape
-
-    if(!m_TensorsInfo[node.input(1)].isConstant())
-    {
-        throw ParseException(boost::str(
-            boost::format("Shape '%1%' should be constant in Reshape layer '%2%' %3%")
-                          % node.input(1)
-                          % node.name()
-                          % CHECK_LOCATION().AsString()));
-    }
-
-    if(m_TensorsInfo[node.input(0)].isConstant())
-    {
-        //make a new cst tensor -> move the data to the output tensor (the shape is already good in the output tensor)
-        if(m_TensorsInfo.count(node.output(0)) == 0)
-        {
-            m_TensorsInfo[node.output(0)] = OnnxTensor();
-        }
-        m_TensorsInfo[node.output(0)].m_tensor =
-            std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
-    }
-    else
-    {
-        TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-
-        if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
-        {
-            auto outInfo = ComputeReshapeInfo(*m_TensorsInfo[node.input(1)].m_tensor, inputShape, node.output(0));
-            m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
-        }
-
-        CreateReshapeLayer(node.input(0), node.output(0), node.name());
-    }
-}
-
 void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
 {
     CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
@@ -1160,66 +1143,148 @@
     ParseActivation(node, ActivationFunction::LeakyReLu);
 }
 
-void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
+void OnnxParser::ParseAdd(const onnx::NodeProto& node)
 {
-    ARMNN_ASSERT(node.op_type() == "Conv");
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
 
-    DepthwiseConvolution2dDescriptor desc;
-    desc.m_PadLeft      = convDesc.m_PadLeft;
-    desc.m_PadRight     = convDesc.m_PadRight;
-    desc.m_PadTop       = convDesc.m_PadTop;
-    desc.m_PadBottom    = convDesc.m_PadBottom;
-    desc.m_StrideX      = convDesc.m_StrideX;
-    desc.m_StrideY      = convDesc.m_StrideY;
-    desc.m_BiasEnabled  = convDesc.m_BiasEnabled;
+    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
 
-    armnn::IConnectableLayer* layer;
-    auto weightTensor = CreateConstTensor(node.input(1));
-    TensorShape& weightShape = weightTensor.first.GetShape();
-    weightShape[1] = weightShape[0];
-    weightShape[0] = 1;
-    m_TensorsInfo[node.input(1)].m_info->SetShape(weightShape);
+    // TODO: unify broadcast validation code across layers
+    // tracked by: IVGCVSW-1576
 
-    if (node.input_size() == 3)
+    // Checking broadcast compatibility : only scalar or 1D tensors
+    auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
+    auto input0 = *m_TensorsInfo[inputs.first].m_info;
+    auto input1 = *m_TensorsInfo[inputs.second].m_info;
+    ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
+
+    unsigned int numDims = input0.GetNumDimensions();
+    for (unsigned int i = 0; i < numDims; i++)
     {
-        if(!m_TensorsInfo[node.input(2)].isConstant())
+        unsigned int dim0 = input0.GetShape()[i];
+        unsigned int dim1 = input1.GetShape()[i];
+        if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
         {
             throw ParseException(boost::str(
-                boost::format("Bias '%1%' should be constant in Conv layer '%2%' %3%")
-                              % node.input(2)
-                              % node.name()
-                              % CHECK_LOCATION().AsString()));
+                boost::format("Broadcast is only supported for scalar or 1D tensors in Add node '%1%'. "
+                              "Input dimensions should either match or one should be of size 1 and here, "
+                              "%2% and %3% %4%")
+                % node.name()
+                % TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
+                                     m_TensorsInfo[inputs.first].m_dtype)
+                % TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
+                                     m_TensorsInfo[inputs.second].m_dtype)
+                % CHECK_LOCATION().AsString()));
         }
-        desc.m_BiasEnabled = true;
-        auto biasTensor = CreateConstTensor(node.input(2));
-        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
-                                                          weightTensor.first,
-                                                          Optional<ConstTensor>(biasTensor.first),
-                                                          node.name().c_str());
     }
-    else
-    {
-        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
-                                                          weightTensor.first,
-                                                          EmptyOptional(),
-                                                          node.name().c_str());
-    }
+
+
+    IConnectableLayer* layer = m_Network->AddAdditionLayer(node.name().c_str());
     ARMNN_ASSERT(layer != nullptr);
 
     auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
-                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
-                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });
-
+                                        { m_TensorsInfo[inputs.first].m_info->GetShape(),
+                                          m_TensorsInfo[inputs.second].m_info->GetShape() });
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
 
-    // register the input connection slots for the layer, connections are made after all layers have been created
-    // only the tensors for the inputs are relevant, exclude the const tensors
-    RegisterInputSlots(layer, {node.input(0)});
+    // register the input connection -> for constant inputs, we need to make a newDim constant layer
+    if(m_TensorsInfo[inputs.first].isConstant()) {
+        CreateConstantLayer(inputs.first, boost::str(boost::format("Add:constant_of_%1%") % node.input(0)));
+    }
+    if(m_TensorsInfo[inputs.second].isConstant()) {
+        CreateConstantLayer(inputs.second, boost::str(boost::format("Add:constant_of_%1%") % node.input(1)));
+    }
+    RegisterInputSlots(layer, {inputs.first, inputs.second});
 
-    // register the output connection slots for the layer, connections are made after all layers have been created
+    // register the output connection
     RegisterOutputSlots(layer, {node.output(0)});
 }
 
+void OnnxParser::ParseAveragePool(const onnx::NodeProto& node)
+{
+    Pooling2dDescriptor desc;
+    desc.m_PoolType = PoolingAlgorithm::Average;
+
+    uint32_t count_include_pad = 0;
+    count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
+    if(count_include_pad) {
+        desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
+    }
+    AddPoolingLayer(node, desc);
+}
+
+void OnnxParser::ParseBatchNormalization(const onnx::NodeProto& node)
+{
+    //IGNORE momentum parameter and spatial parameters
+
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
+    for(int ind = 1; ind < node.input_size(); ++ind)
+    {
+        auto tensor = node.input(ind);
+        if(! m_TensorsInfo[tensor].isConstant())
+        {
+            throw ParseException(boost::str(
+                boost::format("Input tensor '%1%' should be constant in BatchNormalization node '%2%' %3%")
+                % tensor
+                % node.name()
+                % CHECK_LOCATION().AsString()));
+        }
+    }
+
+    float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
+    BatchNormalizationDescriptor desc;
+    desc.m_Eps = epsilon;
+
+    auto scaleTensor = CreateConstTensor(node.input(1));
+    auto biasTensor = CreateConstTensor(node.input(2));
+    auto meanTensor = CreateConstTensor(node.input(3));
+    auto varTensor = CreateConstTensor(node.input(4));
+
+    IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
+                                                                     meanTensor.first,
+                                                                     varTensor.first,
+                                                                     biasTensor.first,
+                                                                     scaleTensor.first,
+                                                                     node.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    RegisterInputSlots(layer, {node.input(0)}); //don't register constant inputs
+
+    // register the output connection
+    RegisterOutputSlots(layer, {node.output(0)});
+}
+
+void OnnxParser::ParseConstant(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
+    if (!node.attribute(0).has_t())
+    {
+        throw ParseException(boost::str(
+              boost::format("Value not found for Constant node '%1%' %2%")
+              % node.name()
+              % CHECK_LOCATION().AsString()));
+    }
+    const onnx::TensorProto& onnxTensor = node.attribute(0).t();
+
+    //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
+    CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(),
+                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
+
+    //Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
+    m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
+    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
+    m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
+
+    CreateConstantLayer(node.output(0), node.name());
+}
+
 void OnnxParser::ParseConv(const onnx::NodeProto& node)
 {
     CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); //input, weight, (bias)
@@ -1231,10 +1296,10 @@
     {
         throw ParseException(boost::str(
             boost::format("ArmNN only supports 2D convolution and Conv layer '%1%' input %2% %3%")
-                          % node.name()
-                          % TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
-                                               m_TensorsInfo[node.input(0)].m_dtype)
-                          % CHECK_LOCATION().AsString()));
+            % node.name()
+            % TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
+                                 m_TensorsInfo[node.input(0)].m_dtype)
+            % CHECK_LOCATION().AsString()));
     }
 
     if(!m_TensorsInfo[node.input(1)].isConstant())
@@ -1262,9 +1327,9 @@
                 throw ParseException(boost::str(
                     boost::format("ArmNN only supports Convolution layers with dilations [1,1], and node '%1%' "
                                   "has dilatation %2% %3%")
-                                   % node.name()
-                                   % ss.str()
-                                   % CHECK_LOCATION().AsString()));
+                    % node.name()
+                    % ss.str()
+                    % CHECK_LOCATION().AsString()));
             }
         }
     }
@@ -1305,7 +1370,7 @@
             {
                 throw ParseException(boost::str(
                     boost::format("Invalid auto_pad attribute for node %1%. "
-                    "Only SAME_UPPER, SAME_LOWER or VALID supported and found %2% %3%")
+                                  "Only SAME_UPPER, SAME_LOWER or VALID supported and found %2% %3%")
                     % node.name()
                     % paddingString
                     % CHECK_LOCATION().AsString()));
@@ -1350,10 +1415,10 @@
                         "Error parsing Convolution node: %1%. "
                         "The 'group'=%2% parameter cannot be larger than the "
                         "channel of the input shape=%3% (in NCHW format). %4%") %
-                        node.name() %
-                        group %
-                        inputInfo.GetShape()[1] %
-                        CHECK_LOCATION().AsString()));
+                    node.name() %
+                    group %
+                    inputInfo.GetShape()[1] %
+                    CHECK_LOCATION().AsString()));
         }
         else if (group == inputInfo.GetShape()[1])
         {
@@ -1368,8 +1433,8 @@
             //  and concatenate the results afterwards
             throw ParseException(boost::str(
                 boost::format("Error parsing Convolution node: %1%. "
-                "The 'group'=%2% parameter should be 1 or be equal to the "
-                "channel of the input shape=%3% (in NCHW format). %4%") %
+                              "The 'group'=%2% parameter should be 1 or be equal to the "
+                              "channel of the input shape=%3% (in NCHW format). %4%") %
                 node.name() %
                 group %
                 inputInfo.GetShape()[1] %
@@ -1386,9 +1451,9 @@
         {
             throw ParseException(boost::str(
                 boost::format("Bias '%1%' should be constant in Conv layer '%2%' %3%")
-                              % node.input(2)
-                              % node.name()
-                              % CHECK_LOCATION().AsString()));
+                % node.input(2)
+                % node.name()
+                % CHECK_LOCATION().AsString()));
         }
         desc.m_BiasEnabled = true;
         auto biasTensor = CreateConstTensor(node.input(2));
@@ -1419,6 +1484,140 @@
     RegisterOutputSlots(layer, {node.output(0)});
 }
 
+void OnnxParser::ParseFlatten(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    CHECK_VALID_DATATYPE(node.name(), node.input(0),
+                         m_TensorsInfo[node.input(0)].m_dtype,
+                         onnx::TensorProto::FLOAT);
+
+    int64_t axis = ReadOptionalNodeInt64Attribute(node, "axis", 1);
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+
+    /// Negative axis conversion
+    if (axis < 0)
+    {
+        axis += inputShape.GetNumDimensions();
+    }
+
+    /// Check Axis is within dimensions
+    if (axis < 0 || axis >= inputShape.GetNumDimensions())
+    {
+        throw ParseException( boost::str(
+            boost::format("Axis '%1%' invalid. Tensor has '%2%' dimensions in FlattenLayer '%3%'")
+            % axis % inputShape.GetNumDimensions() % node.name()));
+    }
+
+    /// If axis chosen is 0 dimension1 will always be 1 in output , default dimension2 to 1 because 0 is invalid
+    uint dimension1{1};
+    uint dimension2{1};
+    uint i{0};
+
+    /// dimension1 = (d_0 * d_1 ... d_(axis-1))
+    for (i = 0; i < axis; i++){
+        dimension1 *= inputShape[i];
+    }
+
+    /// dimension2 = (d_axis * d_(axis+1) ... d_n)
+    for (i = static_cast<uint>(axis); i < inputShape.GetNumDimensions(); i++){
+        dimension2 *= inputShape[i];
+    }
+
+    TensorShape outputShape{dimension1, dimension2};
+
+    auto outInfo = ComputeReshapeInfo(outputShape, inputShape, node.output(0));
+    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+    CreateReshapeLayer(node.input(0), node.output(0), node.name());
+}
+
+void OnnxParser::ParseGlobalAveragePool(const onnx::NodeProto& node)
+{
+    Pooling2dDescriptor desc = Pooling2dDescriptor();
+    desc.m_PoolType = PoolingAlgorithm::Average;
+
+    //kernel size is the same as input
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    desc.m_PoolWidth  = inputShape[3];
+    desc.m_PoolHeight = inputShape[2];
+
+    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    // register the input connection slots for the layer, connections are made after all layers have been created
+    // only the tensors for the inputs are relevant, exclude the const tensors
+    RegisterInputSlots(layer, {node.input(0)});
+
+    // register the output connection slots for the layer, connections are made after all layers have been created
+    RegisterOutputSlots(layer, {node.output(0)});
+}
+
+void OnnxParser::ParseMaxPool(const onnx::NodeProto& node)
+{
+    Pooling2dDescriptor desc;
+    desc.m_PoolType = PoolingAlgorithm::Max;
+    desc.m_PaddingMethod = PaddingMethod::Exclude;
+    AddPoolingLayer(node, desc);
+}
+
+void OnnxParser::ParseReshape(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    CHECK_VALID_DATATYPE(node.name(), node.input(0),
+                         m_TensorsInfo[node.input(0)].m_dtype,
+                         onnx::TensorProto::FLOAT); //input
+    CHECK_VALID_DATATYPE(node.name(), node.input(1),
+                         m_TensorsInfo[node.input(1)].m_dtype,
+                         onnx::TensorProto::INT64); //shape
+
+    if(!m_TensorsInfo[node.input(1)].isConstant())
+    {
+        throw ParseException(boost::str(
+            boost::format("Shape '%1%' should be constant in Reshape layer '%2%' %3%")
+            % node.input(1)
+            % node.name()
+            % CHECK_LOCATION().AsString()));
+    }
+
+    if(m_TensorsInfo[node.input(0)].isConstant())
+    {
+        //make a new cst tensor -> move the data to the output tensor (the shape is already good in the output tensor)
+        if(m_TensorsInfo.count(node.output(0)) == 0)
+        {
+            m_TensorsInfo[node.output(0)] = OnnxTensor();
+        }
+        m_TensorsInfo[node.output(0)].m_tensor =
+            std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
+    }
+    else
+    {
+        TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+
+        if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
+        {
+            uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
+            TensorShape targetShape{static_cast<unsigned int>(dims), 1};
+
+            for(uint i = 0; i < dims; i++)
+            {
+                int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
+                targetShape[i]= static_cast<unsigned int>(val);
+            }
+
+            auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
+            m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+        }
+
+        CreateReshapeLayer(node.input(0), node.output(0), node.name());
+    }
+}
+
 void OnnxParser::PrependForBroadcast(const std::string& outputName,
                                      const std::string& input0,
                                      const std::string& input1)
@@ -1458,134 +1657,6 @@
     }
 }
 
-std::pair<std::string, std::string> OnnxParser::AddPrepareBroadcast(const std::string& input0,
-                                                                    const std::string& input1)
-{
-    std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);
-
-    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
-    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();
-
-    if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
-    {
-        auto outputName = boost::str(boost::format("reshape_output_%1%") % input1);
-        PrependForBroadcast(outputName, input1, input0);
-        inputs.second = outputName;
-    }
-    else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
-    {
-        auto outputName = boost::str(boost::format("reshape_output_%1%") % input0);
-        PrependForBroadcast(outputName, input0, input1);
-        inputs.first = outputName;
-    }
-    return inputs;
-}
-
-void OnnxParser::ParseAdd(const onnx::NodeProto& node)
-{
-    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
-    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
-
-    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
-
-     // TODO: unify broadcast validation code across layers
-     // tracked by: IVGCVSW-1576
-
-     // Checking broadcast compatibility : only scalar or 1D tensors
-     auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
-     auto input0 = *m_TensorsInfo[inputs.first].m_info;
-     auto input1 = *m_TensorsInfo[inputs.second].m_info;
-     ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
-
-     unsigned int numDims = input0.GetNumDimensions();
-     for (unsigned int i = 0; i < numDims; i++)
-     {
-         unsigned int dim0 = input0.GetShape()[i];
-         unsigned int dim1 = input1.GetShape()[i];
-         if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
-         {
-             throw ParseException(boost::str(
-                 boost::format("Broadcast is only supported for scalar or 1D tensors in Add node '%1%'. "
-                               "Input dimensions should either match or one should be of size 1 and here, "
-                               "%2% and %3% %4%")
-                               % node.name()
-                               % TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
-                                                    m_TensorsInfo[inputs.first].m_dtype)
-                               % TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
-                                                    m_TensorsInfo[inputs.second].m_dtype)
-                               % CHECK_LOCATION().AsString()));
-         }
-     }
-
-
-     IConnectableLayer* layer = m_Network->AddAdditionLayer(node.name().c_str());
-     ARMNN_ASSERT(layer != nullptr);
-
-     auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
-                                         { m_TensorsInfo[inputs.first].m_info->GetShape(),
-                                           m_TensorsInfo[inputs.second].m_info->GetShape() });
-     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
-
-     // register the input connection -> for constant inputs, we need to make a newDim constant layer
-     if(m_TensorsInfo[inputs.first].isConstant()) {
-         CreateConstantLayer(inputs.first, boost::str(boost::format("Add:constant_of_%1%") % node.input(0)));
-     }
-     if(m_TensorsInfo[inputs.second].isConstant()) {
-         CreateConstantLayer(inputs.second, boost::str(boost::format("Add:constant_of_%1%") % node.input(1)));
-     }
-     RegisterInputSlots(layer, {inputs.first, inputs.second});
-
-     // register the output connection
-     RegisterOutputSlots(layer, {node.output(0)});
-}
-
-void OnnxParser::ParseBatchNormalization(const onnx::NodeProto& node)
-{
-    //IGNORE momentum parameter and spatial parameters
-
-    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
-    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
-
-    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
-    for(int ind = 1; ind < node.input_size(); ++ind)
-    {
-        auto tensor = node.input(ind);
-        if(! m_TensorsInfo[tensor].isConstant())
-        {
-            throw ParseException(boost::str(
-                boost::format("Input tensor '%1%' should be constant in BatchNormalization node '%2%' %3%")
-                              % tensor
-                              % node.name()
-                              % CHECK_LOCATION().AsString()));
-        }
-    }
-
-    float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
-    BatchNormalizationDescriptor desc;
-    desc.m_Eps = epsilon;
-
-    auto scaleTensor = CreateConstTensor(node.input(1));
-    auto biasTensor = CreateConstTensor(node.input(2));
-    auto meanTensor = CreateConstTensor(node.input(3));
-    auto varTensor = CreateConstTensor(node.input(4));
-
-    IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
-                                                                     meanTensor.first,
-                                                                     varTensor.first,
-                                                                     biasTensor.first,
-                                                                     scaleTensor.first,
-                                                                     node.name().c_str());
-    ARMNN_ASSERT(layer != nullptr);
-
-    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
-    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
-
-    RegisterInputSlots(layer, {node.input(0)}); //don't register constant inputs
-
-    // register the output connection
-    RegisterOutputSlots(layer, {node.output(0)});
-}
-
 void OnnxParser::SetupInputLayers()
 {
     //Find user input and add their layers
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index cc012ff..a87863e 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -89,22 +89,15 @@
     std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
     void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);
 
+    void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
+    void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
+    void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
+
     void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
     void CreateReshapeLayer(const std::string& inputName,
                             const std::string& outputName,
                             const std::string& layerName);
 
-    void ParseBatchNormalization(const onnx::NodeProto& node);
-    void ParseConstant(const onnx::NodeProto& nodeProto);
-
-    void ParseMaxPool(const onnx::NodeProto& nodeProto);
-    void ParseAveragePool(const onnx::NodeProto& nodeProto);
-    void ParseGlobalAveragePool(const onnx::NodeProto& node);
-
-    void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
-
-    void ParseReshape(const onnx::NodeProto& nodeProto);
-
     void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
     void ParseClip(const onnx::NodeProto& nodeProto);
     void ParseSigmoid(const onnx::NodeProto& nodeProto);
@@ -112,11 +105,15 @@
     void ParseRelu(const onnx::NodeProto& nodeProto);
     void ParseLeakyRelu(const onnx::NodeProto& nodeProto);
 
-    void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
-    void ParseConv(const onnx::NodeProto& nodeProto);
-
     void ParseAdd(const onnx::NodeProto& nodeProto);
-    void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
+    void ParseAveragePool(const onnx::NodeProto& nodeProto);
+    void ParseBatchNormalization(const onnx::NodeProto& node);
+    void ParseConstant(const onnx::NodeProto& nodeProto);
+    void ParseConv(const onnx::NodeProto& nodeProto);
+    void ParseFlatten(const onnx::NodeProto& node);
+    void ParseGlobalAveragePool(const onnx::NodeProto& node);
+    void ParseMaxPool(const onnx::NodeProto& nodeProto);
+    void ParseReshape(const onnx::NodeProto& nodeProto);
 
     void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
     void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
@@ -184,5 +181,6 @@
     };
 
     std::vector<UsageSummary> m_OutputsFusedAndUsed;
+
 };
 }
diff --git a/src/armnnOnnxParser/test/Flatten.cpp b/src/armnnOnnxParser/test/Flatten.cpp
new file mode 100644
index 0000000..1ba509e
--- /dev/null
+++ b/src/armnnOnnxParser/test/Flatten.cpp
@@ -0,0 +1,443 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnOnnxParser/IOnnxParser.hpp"
+#include  "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(OnnxParser)
+
+struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    FlattenMainFixture(const std::string& dataType)
+    {
+        m_Prototext = R"(
+                   ir_version: 3
+                   producer_name:  "CNTK"
+                   producer_version:  "2.5.1"
+                   domain:  "ai.cntk"
+                   model_version: 1
+                   graph {
+                     name:  "CNTKGraph"
+                     input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: )" + dataType + R"(
+                            shape {
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                     node {
+                         input: "Input"
+                         output: "Output"
+                         name: "flatten"
+                         op_type: "Flatten"
+                         attribute {
+                           name: "axis"
+                           i: 2
+                           type: INT
+                         }
+                      }
+                      output {
+                          name: "Output"
+                          type {
+                             tensor_type {
+                               elem_type: 1
+                               shape {
+                                   dim {
+                                       dim_value: 4
+                                   }
+                                   dim {
+                                       dim_value: 9
+                                   }
+                               }
+                            }
+                          }
+                       }
+                    }
+                   opset_import {
+                      version: 7
+                    })";
+    }
+};
+
+struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    FlattenDefaultAxisFixture(const std::string& dataType)
+    {
+        m_Prototext = R"(
+                   ir_version: 3
+                   producer_name:  "CNTK"
+                   producer_version:  "2.5.1"
+                   domain:  "ai.cntk"
+                   model_version: 1
+                   graph {
+                     name:  "CNTKGraph"
+                     input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: )" + dataType + R"(
+                            shape {
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                     node {
+                         input: "Input"
+                         output: "Output"
+                         name: "flatten"
+                         op_type: "Flatten"
+                      }
+                      output {
+                          name: "Output"
+                          type {
+                             tensor_type {
+                               elem_type: 1
+                               shape {
+                                   dim {
+                                       dim_value: 2
+                                   }
+                                   dim {
+                                       dim_value: 18
+                                   }
+                               }
+                            }
+                          }
+                       }
+                    }
+                   opset_import {
+                      version: 7
+                    })";
+    }
+};
+
+struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    FlattenAxisZeroFixture(const std::string& dataType)
+    {
+        m_Prototext = R"(
+                   ir_version: 3
+                   producer_name:  "CNTK"
+                   producer_version:  "2.5.1"
+                   domain:  "ai.cntk"
+                   model_version: 1
+                   graph {
+                     name:  "CNTKGraph"
+                     input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: )" + dataType + R"(
+                            shape {
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                     node {
+                         input: "Input"
+                         output: "Output"
+                         name: "flatten"
+                         op_type: "Flatten"
+                         attribute {
+                           name: "axis"
+                           i: 0
+                           type: INT
+                         }
+                      }
+                      output {
+                          name: "Output"
+                          type {
+                             tensor_type {
+                               elem_type: 1
+                               shape {
+                                   dim {
+                                       dim_value: 1
+                                   }
+                                   dim {
+                                       dim_value: 36
+                                   }
+                               }
+                            }
+                          }
+                       }
+                    }
+                   opset_import {
+                      version: 7
+                    })";
+    }
+};
+
+struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    FlattenNegativeAxisFixture(const std::string& dataType)
+    {
+        m_Prototext = R"(
+                   ir_version: 3
+                   producer_name:  "CNTK"
+                   producer_version:  "2.5.1"
+                   domain:  "ai.cntk"
+                   model_version: 1
+                   graph {
+                     name:  "CNTKGraph"
+                     input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: )" + dataType + R"(
+                            shape {
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                     node {
+                         input: "Input"
+                         output: "Output"
+                         name: "flatten"
+                         op_type: "Flatten"
+                         attribute {
+                           name: "axis"
+                           i: -1
+                           type: INT
+                         }
+                      }
+                      output {
+                          name: "Output"
+                          type {
+                             tensor_type {
+                               elem_type: 1
+                               shape {
+                                   dim {
+                                       dim_value: 12
+                                   }
+                                   dim {
+                                       dim_value: 3
+                                   }
+                               }
+                            }
+                          }
+                       }
+                    }
+                   opset_import {
+                      version: 7
+                    })";
+    }
+};
+
+struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    FlattenInvalidNegativeAxisFixture(const std::string& dataType)
+    {
+        m_Prototext = R"(
+                   ir_version: 3
+                   producer_name:  "CNTK"
+                   producer_version:  "2.5.1"
+                   domain:  "ai.cntk"
+                   model_version: 1
+                   graph {
+                     name:  "CNTKGraph"
+                     input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: )" + dataType + R"(
+                            shape {
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                     node {
+                         input: "Input"
+                         output: "Output"
+                         name: "flatten"
+                         op_type: "Flatten"
+                         attribute {
+                           name: "axis"
+                           i: -5
+                           type: INT
+                         }
+                      }
+                      output {
+                          name: "Output"
+                          type {
+                             tensor_type {
+                               elem_type: 1
+                               shape {
+                                   dim {
+                                       dim_value: 12
+                                   }
+                                   dim {
+                                       dim_value: 3
+                                   }
+                               }
+                            }
+                          }
+                       }
+                    }
+                   opset_import {
+                      version: 7
+                    })";
+    }
+};
+
+struct FlattenValidFixture : FlattenMainFixture
+{
+    FlattenValidFixture() : FlattenMainFixture("1") {
+        Setup();
+    }
+};
+
+struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
+{
+    FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
+        Setup();
+    }
+};
+
+struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
+{
+    FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
+        Setup();
+    }
+};
+
+struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
+{
+    FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
+        Setup();
+    }
+};
+
+struct FlattenInvalidFixture : FlattenMainFixture
+{
+    FlattenInvalidFixture() : FlattenMainFixture("10") { }
+};
+
+struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
+{
+    FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
+};
+
+BOOST_FIXTURE_TEST_CASE(ValidFlattenTest, FlattenValidFixture)
+{
+    RunTest<2>({{"Input",
+                          { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                            1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                            1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
+                {{"Output",
+                          { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                            1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                            1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(ValidFlattenDefaultTest, FlattenDefaultValidFixture)
+{
+    RunTest<2>({{"Input",
+                    { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
+               {{"Output",
+                    { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(ValidFlattenAxisZeroTest, FlattenAxisZeroValidFixture)
+{
+    RunTest<2>({{"Input",
+                    { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
+               {{"Output",
+                    { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(ValidFlattenNegativeAxisTest, FlattenNegativeAxisValidFixture)
+{
+    RunTest<2>({{"Input",
+                    { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
+               {{"Output",
+                    { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                        1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeFlatten, FlattenInvalidFixture)
+{
+    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
+}
+
+BOOST_FIXTURE_TEST_CASE(IncorrectAxisFlatten, FlattenInvalidAxisFixture)
+{
+    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnOnnxParser/test/Reshape.cpp b/src/armnnOnnxParser/test/Reshape.cpp
index 60937f0..119a406 100644
--- a/src/armnnOnnxParser/test/Reshape.cpp
+++ b/src/armnnOnnxParser/test/Reshape.cpp
@@ -85,6 +85,91 @@
     }
 };
 
+struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    ReshapeRank4Fixture(const std::string& dataType)
+    {
+        m_Prototext = R"(
+                   ir_version: 3
+                   producer_name:  "CNTK"
+                   producer_version:  "2.5.1"
+                   domain:  "ai.cntk"
+                   model_version: 1
+                   graph {
+                     name:  "CNTKGraph"
+                     input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: )" + dataType + R"(
+                            shape {
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                      input {
+                         name: "Shape"
+                         type {
+                           tensor_type {
+                             elem_type: 7
+                             shape {
+                               dim {
+                                 dim_value: 2
+                               }
+                             }
+                           }
+                         }
+                       }
+                     node {
+                         input: "Input"
+                         input: "Shape"
+                         output: "Output"
+                         name: "reshape"
+                         op_type: "Reshape"
+
+                      }
+                      initializer {
+                        dims: 2
+                        data_type: 7
+                        int64_data: 2
+                        int64_data: 2
+                        name: "Shape"
+                     }
+                      output {
+                          name: "Output"
+                          type {
+                             tensor_type {
+                               elem_type: 1
+                               shape {
+                                   dim {
+                                       dim_value: 6
+                                   }
+                                   dim {
+                                       dim_value: 6
+                                   }
+                               }
+                            }
+                          }
+                       }
+                    }
+                   opset_import {
+                      version: 7
+                    })";
+    }
+};
+
 struct ReshapeValidFixture : ReshapeMainFixture
 {
     ReshapeValidFixture() : ReshapeMainFixture("1") {
@@ -92,6 +177,13 @@
     }
 };
 
+struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
+{
+    ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
+        Setup();
+    }
+};
+
 struct ReshapeInvalidFixture : ReshapeMainFixture
 {
     ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
@@ -102,6 +194,19 @@
     RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
 }
 
+BOOST_FIXTURE_TEST_CASE(ValidRank4ReshapeTest, ReshapeValidRank4Fixture)
+{
+    RunTest<2>(
+        {{"Input",
+                   {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                    1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                    1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
+        {{"Output",
+                    {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                     1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+                     1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
+}
+
 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
 {
    BOOST_CHECK_THROW(Setup(), armnn::ParseException);