IVGCVSW-3649 Add Prelu with different alpha dimension test to TfLiteParser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I982ecd66ea3ed4d88934cd8254832eecb4a7adb4
diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox
index af87eba..761380c 100644
--- a/docs/01_01_parsers.dox
+++ b/docs/01_01_parsers.dox
@@ -133,6 +133,7 @@
- NEG
- PACK
- PAD
+- PRELU
- QUANTIZE
- RELU
- RELU6
diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
index 4cfe2e4..d243a80 100644
--- a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
+++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
@@ -17,7 +17,8 @@
static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition, LayerType::Division,
LayerType::Maximum, LayerType::Minimum,
- LayerType::Multiplication, LayerType::Subtraction };
+ LayerType::Multiplication, LayerType::Prelu,
+ LayerType::Subtraction };
class AddBroadcastReshapeLayerImpl
{
diff --git a/src/armnnTfLiteParser/test/Prelu.cpp b/src/armnnTfLiteParser/test/Prelu.cpp
index b4aa8d7..48a86dc 100644
--- a/src/armnnTfLiteParser/test/Prelu.cpp
+++ b/src/armnnTfLiteParser/test/Prelu.cpp
@@ -106,7 +106,7 @@
struct SimplePreluFixture : PreluFixture
{
SimplePreluFixture() : PreluFixture("[ 2, 3 ]",
- "[ 1, 1 ]",
+ "[ 1 ]",
"[ 2, 3 ]",
"[ 0, 1 ]",
"") {}
@@ -115,13 +115,23 @@
struct PreluConstAlphaFixture : PreluFixture
{
PreluConstAlphaFixture() : PreluFixture(
- "[ 2, 3 ]",
- "[ 2, 3 ]",
- "[ 2, 3 ]",
+ "[ 1, 2, 3 ]",
+ "[ 1, 2, 3 ]",
+ "[ 1, 2, 3 ]",
"[ 0 ]",
"\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
};
+struct PreluBroadcastAlphaFixture : PreluFixture
+{
+ PreluBroadcastAlphaFixture() : PreluFixture(
+ "[ 1, 1, 2, 3 ]",
+ "[ 1, 3 ]",
+ "[ 1, 1, 2, 3 ]",
+ "[ 0 ]",
+ "\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
+};
+
struct PreluDynamicTensorFixture : PreluFixture
{
PreluDynamicTensorFixture() : PreluFixture("[ 2, 3 ]",
@@ -141,7 +151,15 @@
BOOST_FIXTURE_TEST_CASE(PreluConstAlpha, PreluConstAlphaFixture)
{
- RunTest<2, armnn::DataType::Float32>(
+ RunTest<3, armnn::DataType::Float32>(
+ 0,
+ {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
+ {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(PreluBroadcastAlpha, PreluBroadcastAlphaFixture)
+{
+ RunTest<4, armnn::DataType::Float32>(
0,
{{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
{{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});