IVGCVSW-2254 Add Reference workload for Maximum
Change-Id: Id7302c6b1df995ebe6eb8eb94bab38bee1b31b0b
diff --git a/src/backends/backendsCommon/StringMapping.hpp b/src/backends/backendsCommon/StringMapping.hpp
index 6312e68..f6af821 100644
--- a/src/backends/backendsCommon/StringMapping.hpp
+++ b/src/backends/backendsCommon/StringMapping.hpp
@@ -19,6 +19,7 @@
enum Id {
RefAdditionWorkload_Execute,
RefSubtractionWorkload_Execute,
+ RefMaximumWorkload_Execute, // Id passed to the RefMaximum*Workload aliases; NOTE(review): confirm a name is registered for this Id like the other entries (not visible in this hunk)
RefMultiplicationWorkload_Execute,
RefDivisionWorkload_Execute,
MAX_STRING_ID
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 00e4c5c..7222af6 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -308,6 +308,19 @@
&FalseFuncU8<>);
}
+bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{ // Reference-backend support query for the element-wise Maximum layer.
+ ignore_unused(input1); // only input0's data type is checked below; input1 is not validated here
+ ignore_unused(output); // output is not validated here either
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>, // NOTE(review): presumably the Float32 predicate - confirm against IsSupportedForDataTypeRef
+ &TrueFunc<>); // NOTE(review): presumably the Uint8 predicate (RefMaximumUint8Workload exists for QuantisedAsymm8)
+}
+
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
const TensorInfo& output,
const MeanDescriptor& descriptor,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index defa962..73e5394 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -116,6 +116,12 @@
const TensorInfo* cellToOutputWeights,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ /// Reference-backend support query for the element-wise Maximum layer (overrides the base support interface).
+ bool IsMaximumSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsMeanSupported(const TensorInfo& input,
const TensorInfo& output,
const MeanDescriptor& descriptor,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index c93dc31..eef5b24 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -261,7 +261,7 @@
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMaximum(
const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<RefMaximumFloat32Workload, RefMaximumUint8Workload>(descriptor, info); // Float32 / QuantisedAsymm8 reference workloads replace the former NullWorkload pair
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 86c5f90..b9c150f 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -17,6 +17,7 @@
ElementwiseFunction.hpp
FullyConnected.cpp
FullyConnected.hpp
+ Maximum.hpp # header-only: defines the armnn::maximum functor (no Maximum.cpp is added)
Merger.hpp
Pad.cpp
Pad.hpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index bea3d2f..bb15049 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -7,6 +7,8 @@
#include "Broadcast.hpp"
#include <functional>
+#include "Maximum.hpp"
+
namespace armnn
{
@@ -29,3 +31,4 @@
template struct armnn::ElementwiseFunction<std::minus<float>>;
template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
+template struct armnn::ElementwiseFunction<armnn::maximum<float>>; // instantiation for the new Maximum op, alongside the existing arithmetic ops
diff --git a/src/backends/reference/workloads/Maximum.hpp b/src/backends/reference/workloads/Maximum.hpp
new file mode 100644
index 0000000..524afff
--- /dev/null
+++ b/src/backends/reference/workloads/Maximum.hpp
@@ -0,0 +1,25 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <algorithm>
+
+namespace armnn
+{
+    /// Binary functor returning the larger of its two arguments; used as the element-wise
+    /// operation of the reference Maximum workloads. Follows std::max semantics: when
+    /// (inputData0 < inputData1) is false (ties, or either operand NaN) inputData0 is returned.
+    template<typename T>
+    struct maximum
+    {
+        T
+        operator () (const T& inputData0, const T& inputData1) const
+        {
+            return std::max(inputData0, inputData1);
+        }
+    };
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 8e312a7..60a1b99 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -67,3 +67,6 @@
template class armnn::BaseFloat32ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
+
+template class armnn::BaseFloat32ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>; // Maximum, Float32 base - mirrors the Division pair above
+template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>; // Maximum, Uint8 base - mirrors the Division pair above
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 156613a..2772b77 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -9,6 +9,7 @@
#include <backendsCommon/StringMapping.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>
+#include "Maximum.hpp" // armnn::maximum functor used by the RefMaximum*Workload aliases
namespace armnn
{
@@ -119,4 +120,17 @@
DivisionQueueDescriptor,
StringMapping::RefDivisionWorkload_Execute>;
+
+using RefMaximumFloat32Workload = // element-wise Maximum over Float32 tensors
+ RefElementwiseWorkload<armnn::maximum<float>,
+ DataType::Float32,
+ MaximumQueueDescriptor,
+ StringMapping::RefMaximumWorkload_Execute>;
+
+using RefMaximumUint8Workload = // element-wise Maximum over QuantisedAsymm8 tensors; NOTE(review): reuses maximum<float> like the Uint8 Division alias - presumably applied to dequantized values, confirm
+ RefElementwiseWorkload<armnn::maximum<float>,
+ DataType::QuantisedAsymm8,
+ MaximumQueueDescriptor,
+ StringMapping::RefMaximumWorkload_Execute>;
+
} // armnn