IVGCVSW-2367 Add Equal Operator to TfParser
* Added unit tests in Equal.cpp
* Fixed error in Network::AddEqualLayer
* Refactored TfParser::ParseMinimum/ParseEqual to share common code via the new ProcessElementwiseInputSlots and ProcessElementwiseLayer helpers
Change-Id: I0ed6f888eb391c995b88be20dc0c1b916dd14c3c
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 74742a9..b646437 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -353,6 +353,7 @@
{ "AvgPool", &TfParser::ParseAvgPool },
{ "Maximum", &TfParser::ParseMaximum },
{ "Minimum", &TfParser::ParseMinimum },
+ { "Equal", &TfParser::ParseEqual },
{ "Pad", &TfParser::ParsePad },
{ "Sub", &TfParser::ParseSub },
};
@@ -1530,8 +1531,8 @@
}
}
-ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
- const tensorflow::GraphDef& graphDef)
+std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> TfParser::ProcessElementwiseInputSlots(
+ const tensorflow::NodeDef& nodeDef, const std::string& layerName)
{
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
@@ -1555,15 +1556,22 @@
else
{
throw ParseException(
- boost::str(
- boost::format("Unsupported broadcast configuration for Minimum operation %1% %2%")
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ boost::str(
+ boost::format("Unsupported broadcast configuration for %1% operation %2% %3%")
+ % layerName
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
}
+ return {input0Slot, input1Slot};
+}
- IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str());
-
+ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
+ IOutputSlot* input0Slot,
+ IOutputSlot* input1Slot,
+ IConnectableLayer* const layer,
+ const tensorflow::NodeDef& nodeDef)
+{
input0Slot->Connect(layer->GetInputSlot(0));
input1Slot->Connect(layer->GetInputSlot(1));
@@ -1584,6 +1592,30 @@
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef,
+                                          const tensorflow::GraphDef& graphDef)
+{
+    // Equal: gather both operand slots (throws on unsupported broadcast), add the layer, then connect it.
+    const std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> operandSlots =
+        ProcessElementwiseInputSlots(nodeDef, "Equal");
+
+    IConnectableLayer* const equalLayer = m_Network->AddEqualLayer(nodeDef.name().c_str());
+
+    return ProcessElementwiseLayer(operandSlots.first, operandSlots.second, equalLayer, nodeDef);
+}
+
+ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
+                                            const tensorflow::GraphDef& graphDef)
+{
+    // Minimum: gather both operand slots (throws on unsupported broadcast), add the layer, then connect it.
+    const std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> operandSlots =
+        ProcessElementwiseInputSlots(nodeDef, "Minimum");
+
+    IConnectableLayer* const minimumLayer = m_Network->AddMinimumLayer(nodeDef.name().c_str());
+
+    return ProcessElementwiseLayer(operandSlots.first, operandSlots.second, minimumLayer, nodeDef);
+}
+
ParsedTfOperationPtr TfParser::ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
{
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);