IVGCVSW-2914 Add Switch Layer and no-op factory method

Change-Id: I6a6ece708a49e8a97c83a3e7fec11c88af1e1cfa
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp
index 93a4f94..a811706 100644
--- a/src/armnn/InternalTypes.cpp
+++ b/src/armnn/InternalTypes.cpp
@@ -57,6 +57,7 @@
         case LayerType::Splitter: return "Splitter";
         case LayerType::StridedSlice: return "StridedSlice";
         case LayerType::Subtraction: return "Subtraction";
+        case LayerType::Switch: return "Switch";
         default:
             BOOST_ASSERT_MSG(false, "Unknown layer type");
             return "Unknown";
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 7c7c601..5765b5b 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -57,9 +57,10 @@
     SpaceToBatchNd,
     Splitter,
     StridedSlice,
+    Subtraction,
     // Last layer goes here.
     LastLayer,
-    Subtraction = LastLayer
+    Switch = LastLayer
 };
 
 const char* GetLayerTypeAsCString(LayerType type);
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index bc6eec8..320d9ce 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -530,4 +530,15 @@
     FORWARD_LAYER_SUPPORT_FUNC(backend, IsSubtractionSupported, input0, input1, output);
 }
 
+bool IsSwitchSupported(const BackendId& backend,
+                       const TensorInfo& input0,
+                       const TensorInfo& input1,
+                       const TensorInfo& output0,
+                       const TensorInfo& output1,
+                       char* reasonIfUnsupported,
+                       size_t reasonIfUnsupportedMaxLength)
+{
+    FORWARD_LAYER_SUPPORT_FUNC(backend, IsSwitchSupported, input0, input1, output0, output1);
+}
+
 } // namespace armnn
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 0bd68e0..31cfa66 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -50,6 +50,7 @@
 #include "layers/SplitterLayer.hpp"
 #include "layers/StridedSliceLayer.hpp"
 #include "layers/SubtractionLayer.hpp"
+#include "layers/SwitchLayer.hpp"
 
 namespace armnn
 {
@@ -122,5 +123,6 @@
 DECLARE_LAYER(Splitter)
 DECLARE_LAYER(StridedSlice)
 DECLARE_LAYER(Subtraction)
+DECLARE_LAYER(Switch)
 
 }
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 73db2e8..c1462c0 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -971,6 +971,11 @@
     return m_Graph->AddLayer<MergeLayer>(name);
 }
 
+IConnectableLayer* Network::AddSwitchLayer(const char* name)
+{
+    return m_Graph->AddLayer<SwitchLayer>(name);
+}
+
 void Network::Accept(ILayerVisitor& visitor) const
 {
     for (auto layer : GetGraph())
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index bb7b9eb..660ca87 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -176,6 +176,8 @@
 
     IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
 
+    IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
+
     void Accept(ILayerVisitor& visitor) const override;
 
 private:
diff --git a/src/armnn/layers/SwitchLayer.cpp b/src/armnn/layers/SwitchLayer.cpp
new file mode 100644
index 0000000..eae6e0d
--- /dev/null
+++ b/src/armnn/layers/SwitchLayer.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "SwitchLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+SwitchLayer::SwitchLayer(const char* name)
+    : Layer(2, 2, LayerType::Switch, name) // presumably (numInputSlots, numOutputSlots, type, name) - TODO confirm
+{}
+
+std::unique_ptr<IWorkload> SwitchLayer::CreateWorkload(const Graph& graph,
+                                                       const IWorkloadFactory& factory) const
+{
+    SwitchQueueDescriptor descriptor;
+    return factory.CreateSwitch(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+SwitchLayer* SwitchLayer::Clone(Graph& graph) const
+{
+    return CloneBase<SwitchLayer>(graph, GetName());
+}
+
+void SwitchLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(2, CHECK_LOCATION());
+
+    BOOST_ASSERT_MSG(GetNumOutputSlots() == 2, "SwitchLayer: The layer should return 2 outputs.");
+
+    // Assuming first input is the Input and second input is the Constant
+    std::vector<TensorShape> inferredShapes = InferOutputShapes({
+        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "SwitchLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+        GetOutputSlot(0).GetTensorInfo().GetShape(),
+        inferredShapes[0]);
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "SwitchLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
+        GetOutputSlot(1).GetTensorInfo().GetShape(),
+        inferredShapes[0]); // NOTE(review): same inferred shape as OutputSlot[0] - confirm this is intended
+}
+
+void SwitchLayer::Accept(ILayerVisitor& visitor) const
+{
+    visitor.VisitSwitchLayer(this, GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/SwitchLayer.hpp b/src/armnn/layers/SwitchLayer.hpp
new file mode 100644
index 0000000..bfda8c2
--- /dev/null
+++ b/src/armnn/layers/SwitchLayer.hpp
@@ -0,0 +1,42 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "Layer.hpp"
+
+namespace armnn
+{
+
+/// This layer calculates both true and false outputs for input.
+class SwitchLayer : public Layer
+{
+public:
+    /// Makes a workload for the Switch type.
+    /// @param [in] graph The graph where this layer can be found.
+    /// @param [in] factory The workload factory which will create the workload.
+    /// @return A pointer to the created workload, or nullptr if not created.
+    virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+                                                      const IWorkloadFactory& factory) const override;
+
+    /// Creates a dynamically-allocated copy of this layer.
+    /// @param [in] graph The graph into which this layer is being cloned.
+    SwitchLayer* Clone(Graph& graph) const override;
+
+    /// Check if the input tensor shape(s)
+    /// will lead to a valid configuration of @ref SwitchLayer.
+    void ValidateTensorShapesFromInputs() override;
+
+    void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+    /// Constructor to create a SwitchLayer.
+    /// @param [in] name Optional name for the layer.
+    SwitchLayer(const char* name);
+
+    /// Default destructor
+    ~SwitchLayer() = default;
+};
+
+} // namespace armnn
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 09cdd7c..076072e 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -222,6 +222,7 @@
     m_ParserFunctions[Layer_SplitterLayer]               = &Deserializer::ParseSplitter;
     m_ParserFunctions[Layer_StridedSliceLayer]           = &Deserializer::ParseStridedSlice;
     m_ParserFunctions[Layer_SubtractionLayer]            = &Deserializer::ParseSubtraction;
+    m_ParserFunctions[Layer_SwitchLayer]                 = &Deserializer::ParseSwitch;
 }
 
 Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex)
@@ -306,6 +307,8 @@
             return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base();
         case Layer::Layer_SubtractionLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base();
+        case Layer::Layer_SwitchLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base();
         case Layer::Layer_NONE:
         default:
             throw ParseException(boost::str(
@@ -2108,4 +2111,27 @@
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+void Deserializer::ParseSwitch(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_LOCATION();
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 2);
+
+    auto layerName = GetLayerName(graph, layerIndex);
+    IConnectableLayer* layer = m_Network->AddSwitchLayer(layerName.c_str());
+
+    armnn::TensorInfo output0TensorInfo = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(output0TensorInfo);
+
+    armnn::TensorInfo output1TensorInfo = ToTensorInfo(outputs[1]);
+    layer->GetOutputSlot(1).SetTensorInfo(output1TensorInfo);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index df983d9..dfa5b06 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -114,6 +114,7 @@
     void ParseSplitter(GraphPtr graph, unsigned int layerIndex);
     void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
     void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
+    void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
 
     void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot);
     void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 4e5610c..770f7a8 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -41,5 +41,6 @@
 * Splitter
 * StridedSlice
 * Subtraction
+* Switch
 
 More machine learning layers will be supported in future releases.
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 8b275b6..e8d72fc 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -119,7 +119,8 @@
     Lstm = 34,
     Quantize = 35,
     Dequantize = 36,
-    Merge = 37
+    Merge = 37,
+    Switch = 38
 }
 
 // Base layer table to be used as part of other layers
@@ -529,6 +530,10 @@
     base:LayerBase;
 }
 
+table SwitchLayer {
+    base:LayerBase;
+}
+
 union Layer {
     ActivationLayer,
     AdditionLayer,
@@ -567,7 +572,8 @@
     LstmLayer,
     QuantizeLayer,
     DequantizeLayer,
-    MergeLayer
+    MergeLayer,
+    SwitchLayer
 }
 
 table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index fe30c3e..74d0c43 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -893,6 +893,14 @@
     CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
 }
 
+void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
+{
+    auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
+    auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
+
+    CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
+}
+
 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
                                                                      const serializer::LayerType layerType)
 {
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 775df83..4a71837 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -191,6 +191,9 @@
 
     void VisitSubtractionLayer(const armnn::IConnectableLayer* layer,
                                const char* name = nullptr) override;
+
+    void VisitSwitchLayer(const armnn::IConnectableLayer* layer,
+                          const char* name = nullptr) override;
 private:
 
     /// Creates the Input Slots and Output Slots and LayerBase for the layer.
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index a8335e1..5b54bfd 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -41,5 +41,6 @@
 * Splitter
 * StridedSlice
 * Subtraction
+* Switch
 
 More machine learning layers will be supported in future releases.
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index a1ef9ee..2724ba4 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -2113,6 +2113,56 @@
     deserializedNetwork->Accept(verifier);
 }
 
+BOOST_AUTO_TEST_CASE(SerializeSwitch)
+{
+    class SwitchLayerVerifier : public LayerVerifierBase
+    {
+    public:
+        SwitchLayerVerifier(const std::string& layerName,
+                                 const std::vector<armnn::TensorInfo>& inputInfos,
+                                 const std::vector<armnn::TensorInfo>& outputInfos)
+            : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+
+        void VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name) override
+        {
+            VerifyNameAndConnections(layer, name);
+        }
+
+        void VisitConstantLayer(const armnn::IConnectableLayer* layer,
+                                const armnn::ConstTensor& input,
+                                const char *name) override {}
+    };
+
+    const std::string layerName("switch");
+    const armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32);
+
+    std::vector<float> constantData = GenerateRandomData<float>(info.GetNumElements());
+    armnn::ConstTensor constTensor(info, constantData);
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const constantLayer = network->AddConstantLayer(constTensor, "constant");
+    armnn::IConnectableLayer* const switchLayer = network->AddSwitchLayer(layerName.c_str());
+    armnn::IConnectableLayer* const trueOutputLayer = network->AddOutputLayer(0);
+    armnn::IConnectableLayer* const falseOutputLayer = network->AddOutputLayer(1);
+
+    inputLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(0));
+    constantLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(1));
+    switchLayer->GetOutputSlot(0).Connect(trueOutputLayer->GetInputSlot(0));
+    switchLayer->GetOutputSlot(1).Connect(falseOutputLayer->GetInputSlot(0));
+
+    inputLayer->GetOutputSlot(0).SetTensorInfo(info);
+    constantLayer->GetOutputSlot(0).SetTensorInfo(info);
+    switchLayer->GetOutputSlot(0).SetTensorInfo(info);
+    switchLayer->GetOutputSlot(1).SetTensorInfo(info);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    SwitchLayerVerifier verifier(layerName, {info, info}, {info, info});
+    deserializedNetwork->Accept(verifier);
+}
+
 BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork)
 {
     class ConstantLayerVerifier : public LayerVerifierBase
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index fc2d502..6cad7b9 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -397,4 +397,13 @@
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }
 
+bool LayerSupportBase::IsSwitchSupported(const TensorInfo& input0,
+                                         const TensorInfo& input1,
+                                         const TensorInfo& output0,
+                                         const TensorInfo& output1,
+                                         Optional<std::string&> reasonIfUnsupported) const
+{
+    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
 } // namespace armnn
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 7c38b67..3c39f89 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -246,6 +246,12 @@
                                 const TensorInfo& input1,
                                 const TensorInfo& output,
                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
+    bool IsSwitchSupported(const TensorInfo& input0,
+                           const TensorInfo& input1,
+                           const TensorInfo& output0,
+                           const TensorInfo& output1,
+                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 };
 
 } // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 348c864..b850a65 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -75,45 +75,23 @@
 }
 
 //---------------------------------------------------------------
-void ValidateNoInputs(const WorkloadInfo& workloadInfo, std::string const& descName)
+void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
 {
-    if (workloadInfo.m_InputTensorInfos.size() != 0)
+    if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
     {
         throw InvalidArgumentException(descName +
-            ": Requires no inputs. " +
-            to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided.");
-    }
-}
-
-//---------------------------------------------------------------
-void ValidateSingleInput(const WorkloadInfo& workloadInfo, std::string const& descName)
-{
-    if (workloadInfo.m_InputTensorInfos.size() != 1)
-    {
-        throw InvalidArgumentException(descName +
-                                       ": Requires exactly one input. " +
-                                       to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided." );
-    }
-}
-
-//---------------------------------------------------------------
-void ValidateTwoInputs(const WorkloadInfo& workloadInfo, std::string const& descName)
-{
-    if (workloadInfo.m_InputTensorInfos.size() != 2)
-    {
-        throw InvalidArgumentException(descName +
-                                       ": Requires exactly two workloadInfo.m_InputTensorInfos. " +
+                                       ": Requires exactly " + to_string(expectedSize) + " input(s). " +
                                        to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
     }
 }
 
 //---------------------------------------------------------------
-void ValidateSingleOutput(const WorkloadInfo& workloadInfo, std::string const& descName)
+void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
 {
-    if (workloadInfo.m_OutputTensorInfos.size() != 1)
+    if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
     {
         throw InvalidArgumentException(descName +
-                                       ": Requires exactly one output. " +
+                                       ": Requires exactly " + to_string(expectedSize) + " output(s). " +
                                        to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
     }
 }
@@ -242,6 +220,18 @@
     }
 }
 
+//---------------------------------------------------------------
+void ValidateDataTypes(const TensorInfo& info,                              // tensor whose data type is checked
+                       const std::vector<armnn::DataType>& supportedTypes,  // data types that are accepted
+                       std::string const& descName)                         // prefix of the exception message
+{
+    auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
+    if (iterator == supportedTypes.end()) // the tensor's data type is not in the accepted list
+    {
+        throw InvalidArgumentException(descName + ": Tensor type is not supported.");
+    }
+}
+
 } //namespace
 
 void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
@@ -254,8 +244,8 @@
 //---------------------------------------------------------------
 void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "MemCopyQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "MemCopyQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "MemCopyQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "MemCopyQueueDescriptor" , 1);
 
     if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
     {
@@ -299,8 +289,8 @@
 //---------------------------------------------------------------
 void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "ActivationQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "ActivationQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "ActivationQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "ActivationQueueDescriptor", 1);
     ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                               workloadInfo.m_OutputTensorInfos[0],
                               "ActivationQueueDescriptor",
@@ -311,8 +301,8 @@
 //---------------------------------------------------------------
 void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "SoftmaxQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "SoftmaxQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "SoftmaxQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "SoftmaxQueueDescriptor", 1);
 
     ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                               workloadInfo.m_OutputTensorInfos[0],
@@ -324,7 +314,7 @@
 //---------------------------------------------------------------
 void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "SplitterQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "SplitterQueueDescriptor", 1);
 
     if (workloadInfo.m_OutputTensorInfos.size() <= 0)
     {
@@ -372,7 +362,7 @@
 //---------------------------------------------------------------
 void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleOutput(workloadInfo, "MergerQueueDescriptor");
+    ValidateNumOutputs(workloadInfo, "MergerQueueDescriptor", 1);
 
     if (m_Inputs.size() <= 0)
     {
@@ -444,8 +434,8 @@
 //---------------------------------------------------------------
 void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "FullyConnectedQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "FullyConnectedQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FullyConnectedQueueDescriptor", 2, "output");
 
     if (!(workloadInfo.m_InputTensorInfos[0].GetNumDimensions() == 2 ||
@@ -487,8 +477,8 @@
 //---------------------------------------------------------------
 void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "NormalizationQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "NormalizationQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "NormalizationQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "NormalizationQueueDescriptor", 1);
     ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                               workloadInfo.m_OutputTensorInfos[0],
                               "NormalizationQueueDescriptor",
@@ -498,8 +488,8 @@
 
 void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "AdditionQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "AdditionQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -513,8 +503,8 @@
 //---------------------------------------------------------------
 void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "MultiplicationQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "MultiplicationQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -526,8 +516,8 @@
 
 void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "BatchNormalizationQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "BatchNormalizationQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1);
     ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                               workloadInfo.m_OutputTensorInfos[0],
                               "BatchNormalizationQueueDescriptor",
@@ -554,8 +544,8 @@
 
 void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "Convolution2dQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "Convolution2dQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "Convolution2dQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "Convolution2dQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "output");
@@ -580,8 +570,8 @@
 
 void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(
         workloadInfo.m_InputTensorInfos[0], "DepthwiseConvolution2dQueueDescriptor", 4, "input");
@@ -625,8 +615,8 @@
 
 void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "PermuteQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "PermuteQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "PermuteQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "PermuteQueueDescriptor", 1);
 
     const PermutationVector& mapping = m_Parameters.m_DimMappings;
 
@@ -650,8 +640,8 @@
 
 void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "Pooling2dQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "Pooling2dQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "Pooling2dQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "Pooling2dQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "output");
@@ -659,8 +649,8 @@
 
 void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "ResizeBilinearQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "ResizeBilinearQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "output");
@@ -694,8 +684,8 @@
 
 void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "FakeQuantizationQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "FakeQuantizationQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "output");
@@ -713,8 +703,8 @@
 
 void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "L2NormalizationQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "L2NormalizationQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "L2NormalizationQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "L2NormalizationQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "output");
@@ -727,8 +717,8 @@
 
 void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateNoInputs(workloadInfo, "ConstantQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "ConstantQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "ConstantQueueDescriptor", 0);
+    ValidateNumOutputs(workloadInfo, "ConstantQueueDescriptor", 1);
 
     if (!m_LayerOutput)
     {
@@ -744,8 +734,8 @@
 
 void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "ReshapeQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "ReshapeQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "ReshapeQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "ReshapeQueueDescriptor", 1);
 
     if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements())
     {
@@ -757,8 +747,8 @@
 
 void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1);
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output");
@@ -804,8 +794,8 @@
 
 void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "FloorQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "FlootQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "FloorQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "FloorQueueDescriptor", 1);
 
     if (workloadInfo.m_InputTensorInfos[0] != workloadInfo.m_OutputTensorInfos[0])
     {
@@ -821,8 +811,8 @@
 
 void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor");
+    ValidateNumInputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1);
 
     if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32)
     {
@@ -843,8 +833,8 @@
 
 void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor");
+    ValidateNumInputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1);
 
     if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float16)
     {
@@ -864,8 +854,8 @@
 
 void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "DivisionQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "DivisionQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -877,8 +867,8 @@
 
 void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "SubtractionQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "SubtractionQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -890,8 +880,8 @@
 
 void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "MaximumQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "MaximumQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -903,8 +893,8 @@
 
 void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1);
 
     const TensorInfo& input  = workloadInfo.m_InputTensorInfos[0];
     const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
@@ -929,8 +919,8 @@
 
 void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "PadQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "PadQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "PadQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "PadQueueDescriptor", 1);
 
     const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
     const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
@@ -948,8 +938,8 @@
 
 void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "QuantizeQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "QuantizeQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "QuantizeQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "QuantizeQueueDescriptor", 1);
 
 
     if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32)
@@ -966,14 +956,14 @@
 
 void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1);
 }
 
 void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "StridedSliceQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "StridedSliceQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
 
     const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
     const uint32_t rank = input.GetNumDimensions();
@@ -1015,8 +1005,8 @@
 
 void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "MinimumQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "MinimumQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -1028,14 +1018,14 @@
 
 void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "DebugQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "DebugQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "DebugQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "DebugQueueDescriptor", 1);
 }
 
 void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "EqualQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "EqualQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "EqualQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "EqualQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -1052,8 +1042,8 @@
 
 void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "GreaterQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "GreaterQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "GreaterQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "GreaterQueueDescriptor", 1);
 
     ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                                        workloadInfo.m_InputTensorInfos[1],
@@ -1070,8 +1060,8 @@
 
 void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "RsqrtQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "RsqrtQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "RsqrtQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "RsqrtQueueDescriptor", 1);
     ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                               workloadInfo.m_OutputTensorInfos[0],
                               "RsqrtQueueDescriptor",
@@ -1081,8 +1071,8 @@
 
 void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "GatherQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "GatherQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "GatherQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "GatherQueueDescriptor", 1);
 
     const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1];
 
@@ -1102,7 +1092,7 @@
 
 void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2);
 
     if (workloadInfo.m_OutputTensorInfos.size() != 4)
     {
@@ -1155,8 +1145,8 @@
 
 void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateSingleInput(workloadInfo, "DequantizeQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "DequantizeQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "DequantizeQueueDescriptor", 1);
+    ValidateNumOutputs(workloadInfo, "DequantizeQueueDescriptor", 1);
 
     if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedAsymm8 &&
         workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedSymm16)
@@ -1172,8 +1162,8 @@
 
 void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor");
-    ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor");
+    ValidateNumInputs(workloadInfo, "MergeQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "MergeQueueDescriptor", 1);
 
     ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
                               workloadInfo.m_InputTensorInfos[1],
@@ -1192,6 +1182,42 @@
     ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output");
 }
 
+/// Validates the workload info for a Switch workload.
+/// Checks that there are exactly two inputs and two outputs, that every input and
+/// output tensor has a supported data type, and that both outputs match input0's shape.
+/// @param workloadInfo Input and output tensor infos of the workload to check.
+void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    ValidateNumInputs(workloadInfo, "SwitchQueueDescriptor", 2);
+    ValidateNumOutputs(workloadInfo, "SwitchQueueDescriptor", 2);
+
+    // Data types a Switch workload may operate on.
+    std::vector<DataType> supportedTypes = {
+        DataType::Float32,
+        DataType::QuantisedAsymm8,
+        DataType::QuantisedSymm16
+    };
+
+    // Type-check every tensor, including output1.
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, "SwitchQueueDescriptor");
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], supportedTypes, "SwitchQueueDescriptor");
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "SwitchQueueDescriptor");
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[1], supportedTypes, "SwitchQueueDescriptor");
+
+    // Each output must have the same shape as input0.
+    ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+                              workloadInfo.m_OutputTensorInfos[0],
+                              "SwitchQueueDescriptor",
+                              "input0",
+                              "output0");
+
+    ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+                              workloadInfo.m_OutputTensorInfos[1],
+                              "SwitchQueueDescriptor",
+                              "input0",
+                              "output1");
+}
+
 void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
     // This is internally generated so it should not need validation.
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 1bf7352..1b5f86d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -426,4 +426,9 @@
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct SwitchQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 } //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 4ea3ea9..d9774b0 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -729,6 +729,19 @@
                                             reason);
             break;
         }
+        case LayerType::Switch:
+        {
+            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+            const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
+            const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
+            result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
+                                                           OverrideDataType(input1, dataType),
+                                                           OverrideDataType(output0, dataType),
+                                                           OverrideDataType(output1, dataType),
+                                                           reason);
+            break;
+        }
         case LayerType::Mean:
         {
             auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
@@ -1041,4 +1054,10 @@
     return std::unique_ptr<IWorkload>();
 }
 
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
+                                                          const WorkloadInfo& info) const
+{
+    return std::unique_ptr<IWorkload>();
+}
+
 }
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 889bc9d..5c07b3a 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -177,6 +177,9 @@
 
     virtual std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
                                                           const WorkloadInfo& Info) const;
+
+    virtual std::unique_ptr<IWorkload> CreateSwitch(const SwitchQueueDescriptor& descriptor,
+                                                    const WorkloadInfo& Info) const;
 };
 
 } //namespace armnn
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 0588607..a7d7b09 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -402,6 +402,8 @@
 
 DECLARE_LAYER_POLICY_1_PARAM(Subtraction)
 
+DECLARE_LAYER_POLICY_1_PARAM(Switch)
+
 
 // Generic implementation to get the number of input slots for a given layer type;
 template<armnn::LayerType Type>