IVGCVSW-1863 Support NHWC for L2Normalization
* Added L2NormalizationDescriptor struct with m_DataLayout member
* Updated all IsL2NormalizationSupported calls to take a descriptor
as an argument
* Updated L2NormalizationLayer to take a descriptor as an argument
!android-nn-driver:150116
Change-Id: I0459352d19cfd269bc864a70cf73910bf44fdc01
diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp
index 39bc259..cc0e12d 100644
--- a/src/backends/test/CreateWorkloadCl.cpp
+++ b/src/backends/test/CreateWorkloadCl.cpp
@@ -524,13 +524,14 @@
CreateMemCopyWorkloads<IClTensorHandle>(factory);
}
-BOOST_AUTO_TEST_CASE(CreateL2NormalizationWorkload)
+template <typename L2NormalizationWorkloadType, typename armnn::DataType DataType>
+static void ClL2NormalizationWorkloadTest(DataLayout dataLayout)
{
Graph graph;
ClWorkloadFactory factory;
- auto workload = CreateL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float32>
- (factory, graph);
+ auto workload = CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>
+ (factory, graph, dataLayout);
// Checks that inputs/outputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
L2NormalizationQueueDescriptor queueDescriptor = workload->GetData();
@@ -541,6 +542,26 @@
BOOST_TEST(CompareIClTensorHandleShape(outputHandle, { 5, 20, 50, 67 }));
}
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloatNchwWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloatNhwcWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat16NchwWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat16NhwcWorkload)
+{
+ ClL2NormalizationWorkloadTest<ClL2NormalizationFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC);
+}
+
template <typename LstmWorkloadType>
static void ClCreateLstmWorkloadTest()
{