IVGCVSW-6682 Add ReplaceTensorHandle functions to IWorkload and BaseWorkload

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I9f80b9f45206db920568e28e363fcb60f5c0819a
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index e865f25..6dbbd55 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -7,6 +7,7 @@
 
 #include <armnn/utility/PolymorphicDowncast.hpp>
 #include <reference/RefTensorHandle.hpp>
+#include <reference/RefTensorHandleFactory.hpp>
 #include <reference/RefWorkloadFactory.hpp>
 #include <reference/workloads/RefWorkloads.hpp>
 
@@ -46,7 +47,6 @@
     return RefWorkloadFactory(memoryManager);
 }
 
-
 }
 
 TEST_SUITE("CreateWorkloadRef")
@@ -1271,4 +1271,58 @@
     RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
 }
 
+template <armnn::DataType DataType>
+static void RefCreateActivationWorkloadReplaceFunctionsTest()
+{
+    Graph graph;
+    RefWorkloadFactory factory = GetFactory();
+    // CreateActivationWorkloadTest creates the input and output as armnn::TensorInfo({1, 1}, DataType).
+    auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph);
+
+    // Create new input and output tensor handles and swap them into the workload. They deliberately use a
+    // different shape and data type from the originals so the TensorInfo checks below cannot pass by accident.
+    // Note: the template parameter 'DataType' shadows the enum name, so the enum is fully qualified here.
+    auto memoryManager = std::make_shared<RefMemoryManager>();
+    const RefTensorHandleFactory tensorHandleFactory(memoryManager);
+    const TensorInfo inputInfo({2, 2}, armnn::DataType::Float16);
+    const TensorInfo outputInfo({2, 2}, armnn::DataType::Float16);
+    std::unique_ptr<ITensorHandle> inputHandle  = tensorHandleFactory.CreateTensorHandle(inputInfo);
+    std::unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
+    const unsigned int slot = 0;
+    workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot);
+    workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot);
+
+    // Check that the tensor handles held by the workload are exactly the ones we replaced them with.
+    // Bind by reference so we inspect the workload's own descriptor rather than a copy of it.
+    const auto& queueDescriptor = workloadPtr->GetData();
+    REQUIRE(queueDescriptor.m_Inputs.size() == 1);
+    REQUIRE(queueDescriptor.m_Outputs.size() == 1);
+    auto inputHandleTest  = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
+    auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
+    CHECK((inputHandleTest->GetTensorInfo() == inputInfo));
+    CHECK((outputHandleTest->GetTensorInfo() == outputInfo));
+    CHECK(inputHandle.get() == inputHandleTest);
+    CHECK(outputHandle.get() == outputHandleTest);
+
+    // Mapping through either pointer must yield the same memory; every Map() is balanced by an Unmap().
+    inputHandle->Allocate();
+    CHECK(inputHandle->Map() == inputHandleTest->Map());
+    inputHandleTest->Unmap();
+    inputHandle->Unmap();
+    outputHandle->Allocate();
+    CHECK(outputHandle->Map() == outputHandleTest->Map());
+    outputHandleTest->Unmap();
+    outputHandle->Unmap();
+}
+
+TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload")
+{
+    RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>();
+}
+
+TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
+{
+    RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
+}
+
 }
diff --git a/src/backends/reference/workloads/RefActivationWorkload.hpp b/src/backends/reference/workloads/RefActivationWorkload.hpp
index e3bd870..9814ac1 100644
--- a/src/backends/reference/workloads/RefActivationWorkload.hpp
+++ b/src/backends/reference/workloads/RefActivationWorkload.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -16,7 +16,7 @@
 public:
     using BaseWorkload<ActivationQueueDescriptor>::BaseWorkload;
     void Execute() const override;
-    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)  override;
+    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
 
 private:
     void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;