IVGCVSW-5416 Add android-nn-driver support for CAST
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I02da912e5e4ca650b367ca40fe3f5ca5baa61cbb
diff --git a/src/backends/reference/workloads/RefCastWorkload.cpp b/src/backends/reference/workloads/RefCastWorkload.cpp
index 7080415..8f2a725 100644
--- a/src/backends/reference/workloads/RefCastWorkload.cpp
+++ b/src/backends/reference/workloads/RefCastWorkload.cpp
@@ -26,15 +26,38 @@
namespace armnn
{
- void RefCastWorkload::Execute() const
- {
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute");
- const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+void RefCastWorkload::Execute() const
+{
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
- Cast(*MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()),
- *MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()),
- inputInfo.GetNumElements());
+void RefCastWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
+{
+ Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
+}
+
+void RefCastWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute");
+
+ TensorInfo inputTensorInfo(GetTensorInfo(inputs[0]));
+ TensorInfo outputTensorInfo(GetTensorInfo(outputs[0]));
+
+ // Quantization info should be set to default values.
+ if (inputTensorInfo.IsQuantized())
+ {
+ inputTensorInfo.SetQuantizationScale(1.0f);
+ inputTensorInfo.SetQuantizationOffset(0);
}
+ if (outputTensorInfo.IsQuantized())
+ {
+ outputTensorInfo.SetQuantizationScale(1.0f);
+ outputTensorInfo.SetQuantizationOffset(0);
+ }
+
+ Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
+ *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
+ inputTensorInfo.GetNumElements());
+}
} //namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp
index 6742ef0..870fb41 100644
--- a/src/backends/reference/workloads/RefCastWorkload.hpp
+++ b/src/backends/reference/workloads/RefCastWorkload.hpp
@@ -18,6 +18,9 @@
public:
using BaseWorkload<CastQueueDescriptor>::BaseWorkload;
void Execute() const override;
+ void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
+private:
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
};
} //namespace armnn