IVGCVSW-2510 Ref workload implementation for Gather operator
* add implementation for GatherQueueDescriptor validate function
* add FirstInputTypedWorkload to allow type check on the first input tensor only
* add Ref workload implementation for float and uint8
* add Gather layer support in Ref
* add unit tests
Change-Id: I4578a3211f11d24aa29d15bcf7f45b0445bcd1ee
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 309d53f..6539219 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -116,6 +116,7 @@
return it.GetDataType() == InputDataType;
}),
"Trying to create workload with incorrect type");
+
BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
info.m_OutputTensorInfos.end(),
[&](auto it){
@@ -125,6 +126,30 @@
}
};
+// Workload whose FIRST input tensor (if there is one) must have type DataType; any further inputs
+// are not type-checked. Every output must have type DataType. Checks use BOOST_ASSERT_MSG.
+template <typename QueueDescriptor, armnn::DataType DataType>
+class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+        : BaseWorkload<QueueDescriptor>(descriptor, info)
+    {
+        if (!info.m_InputTensorInfos.empty())
+        {
+            BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
+                             "Trying to create workload with incorrect type");
+        }
+        // Bind by const reference to avoid copying each output TensorInfo; no capture is needed.
+        BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
+                                     info.m_OutputTensorInfos.end(),
+                                     [](const auto& it){
+                                         return it.GetDataType() == DataType;
+                                     }),
+                         "Trying to create workload with incorrect type");
+    }
+};
+
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
armnn::DataType::Float16,