IVGCVSW-2871 Ref QuantizeLayer workload

Implement QuantizeLayer::CreateWorkload and add RefQuantizeWorkload to the
reference backend, quantizing Float32 inputs to QuantisedAsymm8 or
QuantisedSymm16 outputs.

Change-Id: If048b2a053c542b31ae344fe0af04d9b4f40eb6d
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/src/armnn/layers/QuantizeLayer.cpp b/src/armnn/layers/QuantizeLayer.cpp
index fbf8b32..d5d76e2 100644
--- a/src/armnn/layers/QuantizeLayer.cpp
+++ b/src/armnn/layers/QuantizeLayer.cpp
@@ -19,7 +19,11 @@
 std::unique_ptr<IWorkload> QuantizeLayer::CreateWorkload(const Graph& graph,
                                                          const IWorkloadFactory& factory) const
 {
-    return nullptr;
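+    // PrepInfoAndDesc fills the queue descriptor and collects this layer's
+    // input/output TensorInfos; the backend factory then builds the workload.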
+    QuantizeQueueDescriptor descriptor;
+    WorkloadInfo info = PrepInfoAndDesc(descriptor, graph);
+    return factory.CreateQuantize(descriptor, info);
 }
 
 Layer* QuantizeLayer::Clone(Graph& graph) const
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 532c8ea..4d164d5 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -122,6 +122,16 @@
     }
 };
 
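+// Rule that passes when two tensor shapes contain the same total number of
+// elements, even if their dimensionalities differ.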
+struct ShapesAreSameTotalSize : public Rule
+{
+    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
+    {
+        m_Res = info0.GetNumElements() == info1.GetNumElements();
+    }
+};
+
 struct ShapesAreBroadcastCompatible : public Rule
 {
     unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
@@ -719,6 +729,34 @@
                                      &TrueFunc<>);
 }
 
+bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
+                                          const TensorInfo& output,
+                                          Optional<std::string&> reasonIfUnsupported) const
+{
+    bool supported = true;
+
+    // Define supported input types.
+    std::array<DataType,1> supportedInputTypes = {
+        DataType::Float32
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
+                                  "Reference quantize: input type not supported.");
+
+    // Define supported output types.
+    std::array<DataType,2> supportedOutputTypes = {
+        DataType::QuantisedAsymm8,
+        DataType::QuantisedSymm16
+    };
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
+                                  "Reference quantize: output type not supported.");
+
+    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+                                  "Reference quantize: input and output shapes have different total numbers of elements.");
+
+    return supported;
+}
+
 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
                                          const ReshapeDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 42a5a44..53a1abf 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -196,6 +196,10 @@
                               const Pooling2dDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsQuantizeSupported(const TensorInfo& input,
+                             const TensorInfo& output,
+                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsReshapeSupported(const TensorInfo& input,
                             const ReshapeDescriptor& descriptor,
                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index dda1819..7fbd359 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -348,6 +348,12 @@
     return nullptr;
 }
 
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
+                                                              const WorkloadInfo& info) const
+{
+    return std::make_unique<RefQuantizeWorkload>(descriptor, info);
+}
+
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
 {
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 14d3178..86f1ec3 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -180,6 +180,9 @@
     std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor,
                                                 const WorkloadInfo& info) const override;
 
+    std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
+                                              const WorkloadInfo& info) const override;
+
 private:
 
     template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 90aa63a..f2b1153 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -59,6 +59,7 @@
         workloads/RefPermuteWorkload.cpp \
         workloads/RefPooling2dFloat32Workload.cpp \
         workloads/RefPooling2dUint8Workload.cpp \
+        workloads/RefQuantizeWorkload.cpp \
         workloads/RefReshapeFloat32Workload.cpp \
         workloads/RefReshapeUint8Workload.cpp \
         workloads/RefResizeBilinearFloat32Workload.cpp \
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index c4fc202..4f5fbb5 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -97,6 +97,8 @@
     RefPooling2dFloat32Workload.hpp
     RefPooling2dUint8Workload.cpp
     RefPooling2dUint8Workload.hpp
+    RefQuantizeWorkload.cpp
+    RefQuantizeWorkload.hpp
     RefReshapeFloat32Workload.cpp
     RefReshapeFloat32Workload.hpp
     RefReshapeUint8Workload.cpp
diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.cpp b/src/backends/reference/workloads/RefQuantizeWorkload.cpp
new file mode 100644
index 0000000..b7ace32
--- /dev/null
+++ b/src/backends/reference/workloads/RefQuantizeWorkload.cpp
@@ -0,0 +1,73 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefQuantizeWorkload.hpp"
+
+#include <armnn/TypesUtils.hpp>
+
+#include <boost/assert.hpp>
+
+namespace armnn
+{
+
+namespace
+{
+
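+// Quantizes numValues Float32 elements to integer type T, mapping each value
+// to round(value / scale) + offset and clamping to the range of T.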
+template<typename T>
+void QuantizeImpl(const void* input, void* output, size_t numValues, float scale, int offset)
+{
+    auto in = static_cast<const float*>(input);
+    auto out = static_cast<T*>(output);
+    for (size_t i = 0; i < numValues; i++, in++, out++)
+    {
+        *out = armnn::Quantize<T>(*in, scale, offset);
+    }
+}
+
+} //namespace
+
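+// The element count and the output tensor's quantization parameters are
+// cached at construction so Execute() does not re-query them on every call.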
+RefQuantizeWorkload::RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info)
+    : BaseWorkload(descriptor, info)
+    , m_NumElements(info.m_InputTensorInfos[0].GetNumElements())
+    , m_TargetType(info.m_OutputTensorInfos[0].GetDataType())
+    , m_Scale(info.m_OutputTensorInfos[0].GetQuantizationScale())
+    , m_Offset(info.m_OutputTensorInfos[0].GetQuantizationOffset())
+{
+}
+
+void RefQuantizeWorkload::Execute() const
+{
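+    // Map the tensor buffers for CPU access; 'true' requests a blocking map.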
+    const void* input = m_Data.m_Inputs[0]->Map(true);
+    void* output = m_Data.m_Outputs[0]->Map(true);
+
+    switch(m_TargetType)
+    {
+        case DataType::QuantisedAsymm8:
+        {
+            QuantizeImpl<uint8_t>(input, output, m_NumElements, m_Scale, m_Offset);
+            break;
+        }
+        case DataType::QuantisedSymm16:
+        {
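+            // Symmetric 16-bit quantization uses a zero offset by definition.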
+            QuantizeImpl<int16_t>(input, output, m_NumElements, m_Scale, 0);
+            break;
+        }
+        default:
+        {
+            BOOST_ASSERT_MSG(false, "RefQuantizeWorkload: Non-quantized output type encountered");
+        }
+    }
+
+    m_Data.m_Inputs[0]->Unmap();
+    m_Data.m_Outputs[0]->Unmap();
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.hpp b/src/backends/reference/workloads/RefQuantizeWorkload.hpp
new file mode 100644
index 0000000..6a43b84
--- /dev/null
+++ b/src/backends/reference/workloads/RefQuantizeWorkload.hpp
@@ -0,0 +1,27 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefQuantizeWorkload : public BaseWorkload<QuantizeQueueDescriptor>
+{
+public:
+    RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info);
+    void Execute() const override;
+
+private:
+    size_t m_NumElements;
+    DataType m_TargetType;
+    float m_Scale;
+    int m_Offset;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7d2e813..77aa56f 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -64,3 +64,4 @@
 #include "RefRsqrtFloat32Workload.hpp"
 #include "RefComparisonWorkload.hpp"
 #include "RefDequantizeWorkload.hpp"
+#include "RefQuantizeWorkload.hpp"