IVGCVSW-4906 Add front-end support for FILL operator
* Added new Fill layer and FillDescriptor
* Added IsFillSupported, AddFillLayer and VisitFillLayer to the public API
* Added visitor tests
Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com>
Change-Id: Iea677014866b4f2d514004623f59ee83f3c0eef8
Signed-off-by: Keith Davis <keith.davis@arm.com>
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 5d0990e..653e647 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -711,6 +711,26 @@
float m_Max;
};
+/// A FillDescriptor for the FillLayer.
+struct FillDescriptor
+{
+ FillDescriptor()
+ : m_Value(0.0f)
+ {}
+
+ FillDescriptor(float value)
+ : m_Value(value)
+ {}
+
+ bool operator ==(const FillDescriptor& rhs) const
+ {
+ return m_Value == rhs.m_Value;
+ }
+
+ /// Value used by the FillLayer to fill its output.
+ float m_Value;
+};
+
/// A ResizeBilinearDescriptor for the ResizeBilinearLayer.
struct ResizeBilinearDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 1c813b5..e31fb96 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -18,6 +18,7 @@
struct DetectionPostProcessDescriptor;
struct ElementwiseUnaryDescriptor;
struct FakeQuantizationDescriptor;
+struct FillDescriptor;
struct FullyConnectedDescriptor;
struct InstanceNormalizationDescriptor;
struct L2NormalizationDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 58509c9..33389eb 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -157,6 +157,11 @@
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsFillSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const FillDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsFloorSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 530e74f..aa5bdba 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -184,6 +184,14 @@
virtual void VisitEqualLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a fill layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param fillDescriptor - FillDescriptor that configures the layer.
+ /// @param name - Optional name for the layer.
+ virtual void VisitFillLayer(const IConnectableLayer* layer,
+ const FillDescriptor& fillDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function a floor layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 1dd949d..ade6c52 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -213,10 +213,17 @@
/// Add an ElementwiseUnary layer to the network.
/// @param name - Optional name for the layer.
/// @param desc - Descriptor for the elementwiseUnary operation.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
const char* name = nullptr) = 0;
+ /// Adds a Fill layer to the network.
+ /// @param fillDescriptor - Descriptor for the fill operation.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor,
+ const char* name = nullptr) = 0;
+
/// Adds a fully connected layer to the network.
/// @param fullyConnectedDescriptor - Description of the fully connected layer.
/// @param weights - Tensor for the weights data.
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 95d6bd3..0dc5e54 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -101,6 +101,10 @@
void VisitEqualLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitFillLayer(const IConnectableLayer*,
+ const FillDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitFloorLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }