IVGCVSW-3267 Increase code coverage of the PReLU layer
* Added more unit tests to cover all code branches
* Moved the InferOutput tests to separate files
* Created a convenience ARMNN_SIMPLE_TEST_CASE macro
* Created a TestUtils file for common utility functions
Change-Id: Id971d3cf77005397d1f0b2783fab68b1f0bf9dfc
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 14615f8..d174093 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -870,32 +870,60 @@
RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
}
-template <typename armnn::DataType DataType>
-static void RefCreatePreluWorkloadTest(const armnn::TensorShape& outputShape)
+// Creates a RefPreluWorkload via CreatePreluWorkloadTest from the given input, alpha and output
+// shapes and data type, then checks that the workload's output tensor info equals
+// TensorInfo(outputShape, dataType).
+// Exceptions raised while creating the workload are not caught here, so callers can wrap this
+// helper in BOOST_CHECK_THROW to test invalid shape combinations.
+static void RefCreatePreluWorkloadTest(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& alphaShape,
+ const armnn::TensorShape& outputShape,
+ armnn::DataType dataType)
{
armnn::Graph graph;
RefWorkloadFactory factory;
- auto workload = CreatePreluWorkloadTest<RefPreluWorkload, DataType>(factory, graph, outputShape);
+ // The data type is now a runtime argument rather than a template parameter, so the same
+ // helper serves the Float32, QuantisedAsymm8 and QuantisedSymm16 test cases.
+ auto workload = CreatePreluWorkloadTest<RefPreluWorkload>(factory,
+ graph,
+ inputShape,
+ alphaShape,
+ outputShape,
+ dataType);
// Check output is as expected
auto queueDescriptor = workload->GetData();
auto outputHandle = boost::polymorphic_downcast<CpuTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
+ BOOST_TEST((outputHandle->GetTensorInfo() == TensorInfo(outputShape, dataType)));
}
BOOST_AUTO_TEST_CASE(CreatePreluFloat32Workload)
{
- RefCreatePreluWorkloadTest<armnn::DataType::Float32>({ 5, 4, 3, 2 });
+ // Input { 1, 4, 1, 2 } and alpha { 5, 4, 3, 1 } differ, but in every dimension the sizes are
+ // either equal or one of them is 1, so they are expected to broadcast to { 5, 4, 3, 2 }.
+ RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float32);
}
BOOST_AUTO_TEST_CASE(CreatePreluUint8Workload)
{
- RefCreatePreluWorkloadTest<armnn::DataType::QuantisedAsymm8>({ 5, 4, 3, 2 });
+ // Same broadcastable shapes as the Float32 case, using QuantisedAsymm8 data.
+ RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QuantisedAsymm8);
}
BOOST_AUTO_TEST_CASE(CreatePreluInt16Workload)
{
- RefCreatePreluWorkloadTest<armnn::DataType::QuantisedSymm16>({ 5, 4, 3, 2 });
+ // Same broadcastable shapes as the Float32 case, using QuantisedSymm16 data.
+ RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QuantisedSymm16);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluFloat32NoBroadcastWorkload)
+{
+ // Dimension 2 of the input (7) and of alpha (3) differ and neither is 1, so the shapes cannot be
+ // broadcast together and creating the workload is expected to throw InvalidArgumentException.
+ BOOST_CHECK_THROW(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
+ armnn::DataType::Float32),
+ armnn::InvalidArgumentException);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluUint8NoBroadcastWorkload)
+{
+ // Non-broadcastable input/alpha shapes (7 vs 3 in dimension 2) with QuantisedAsymm8 data.
+ BOOST_CHECK_THROW(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
+ armnn::DataType::QuantisedAsymm8),
+ armnn::InvalidArgumentException);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePreluInt16NoBroadcastWorkload)
+{
+ // Non-broadcastable input/alpha shapes (7 vs 3 in dimension 2) with QuantisedSymm16 data.
+ BOOST_CHECK_THROW(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
+ armnn::DataType::QuantisedSymm16),
+ armnn::InvalidArgumentException);
}
BOOST_AUTO_TEST_SUITE_END()