Release 18.08
diff --git a/src/armnn/backends/Workload.hpp b/src/armnn/backends/Workload.hpp
index dbc7574..5da03bc 100644
--- a/src/armnn/backends/Workload.hpp
+++ b/src/armnn/backends/Workload.hpp
@@ -12,11 +12,11 @@
 namespace armnn
 {
 
-// Workload interface to enqueue a layer computation
+// Workload interface to enqueue a layer's computation; implementations do the work in Execute().
 class IWorkload
 {
 public:
-    virtual ~IWorkload(){};
+    virtual ~IWorkload() = default;
 
     virtual void Execute() const = 0;
 };
@@ -46,7 +46,8 @@
     const QueueDescriptor m_Data;
 };
 
-template <typename QueueDescriptor, armnn::DataType DataType>
+// Workload whose input/output tensors must all share one data type, taken from the DataTypes... list.
+template <typename QueueDescriptor, armnn::DataType... DataTypes>
 class TypedWorkload : public BaseWorkload<QueueDescriptor>
 {
 public:
@@ -54,27 +55,106 @@
     TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
         : BaseWorkload<QueueDescriptor>(descriptor, info)
     {
+        // All type checking happens in HasValidDataTypes(), which is only evaluated inside the assert,
+        // so it costs nothing (no allocation, no scans) when asserts are compiled out.
+        BOOST_ASSERT_MSG(HasValidDataTypes(info), "Trying to create workload with incorrect type");
+    }
+
+private:
+    // Returns true when:
+    //  - all input tensors share one data type, and that type is one of DataTypes...;
+    //  - all output tensors share one data type, which must equal the input data type when there are
+    //    inputs, or be one of DataTypes... when there are none.
+    // An empty input or output list imposes no constraint of its own.
+    static bool HasValidDataTypes(const WorkloadInfo& info)
+    {
+        const std::vector<armnn::DataType> allowedTypes = {DataTypes...};
+
+        // True if 'type' is one of the DataTypes this workload was instantiated for.
+        auto isAllowed = [&allowedTypes](armnn::DataType type)
+        {
+            return std::find(allowedTypes.begin(), allowedTypes.end(), type) != allowedTypes.end();
+        };
+
+        // True if every tensor in 'tensorInfos' has data type 'type' (vacuously true for an empty list).
+        // Elements are bound by const reference so no TensorInfo is copied.
+        auto allHaveType = [](const auto& tensorInfos, armnn::DataType type)
+        {
+            return std::all_of(tensorInfos.begin(), tensorInfos.end(),
+                               [type](const auto& tensorInfo) { return tensorInfo.GetDataType() == type; });
+        };
+
+        const auto& inputs  = info.m_InputTensorInfos;
+        const auto& outputs = info.m_OutputTensorInfos;
+
+        if (!inputs.empty())
+        {
+            const armnn::DataType inputType = inputs.front().GetDataType();
+            if (!isAllowed(inputType) || !allHaveType(inputs, inputType))
+            {
+                return false;
+            }
+        }
+
+        if (!outputs.empty())
+        {
+            const armnn::DataType outputType = outputs.front().GetDataType();
+            // With inputs present the outputs must use the input data type; without inputs the
+            // output data type itself must be one of the allowed types.
+            const bool outputTypeOk = inputs.empty() ? isAllowed(outputType)
+                                                     : outputType == inputs.front().GetDataType();
+            if (!outputTypeOk || !allHaveType(outputs, outputType))
+            {
+                return false;
+            }
+        }
+
+        return true;
+    }
+};
+
+template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
+class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+        : BaseWorkload<QueueDescriptor>(descriptor, info)
+    {
         BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                      info.m_InputTensorInfos.end(),
                                      [&](auto it){
-                                         return it.GetDataType() == DataType;
+                                         return it.GetDataType() == InputDataType;
                                      }),
                          "Trying to create workload with incorrect type");
         BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                      info.m_OutputTensorInfos.end(),
                                      [&](auto it){
-                                         return it.GetDataType() == DataType;
+                                         return it.GetDataType() == OutputDataType;
                                      }),
                          "Trying to create workload with incorrect type");
     }
-
-    static constexpr armnn::DataType ms_DataType = DataType;
 };
 
 template <typename QueueDescriptor>
+using FloatWorkload = TypedWorkload<QueueDescriptor,
+                                    armnn::DataType::Float16,
+                                    armnn::DataType::Float32>;
+
+template <typename QueueDescriptor>
 using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
 
 template <typename QueueDescriptor>
 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
 
+template <typename QueueDescriptor>
+using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
+                                                    armnn::DataType::Float16,
+                                                    armnn::DataType::Float32>;
+
+template <typename QueueDescriptor>
+using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
+                                                    armnn::DataType::Float32,
+                                                    armnn::DataType::Float16>;
+
 } //namespace armnn