Add Split support to TOSA Reference Backend
* Resolves IVGCVSW-7918
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: Ic2afaa55f7ee88ce4c9b8ea696eef5f28663f8c6
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
index 04be52d..ec6fc3b 100644
--- a/src/backends/tosaReference/TosaRefLayerSupport.cpp
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -74,9 +74,20 @@
case LayerType::Resize:
case LayerType::Slice:
case LayerType::Transpose:
+ {
inputInfos.push_back(&infos[0]);
outputInfos.push_back(&infos[1]);
break;
+ }
+ case LayerType::Splitter: // infos = { input, output0, output1, ... }
+ {
+ inputInfos.push_back(&infos[0]);
+ for (unsigned int i = 1; i < infos.size(); ++i) // unlike the cases above, any number of outputs follows the input
+ {
+ outputInfos.push_back(&infos[i]);
+ }
+ break;
+ }
case LayerType::TransposeConvolution2d:
{
inputInfos.push_back(&infos[0]); // input
diff --git a/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp b/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp
index b35dacb..ae90c66 100644
--- a/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp
+++ b/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp
@@ -14,6 +14,7 @@
#include "backendsCommon/test/ResizeEndToEndTestImpl.hpp"
#include "backendsCommon/test/ElementwiseUnaryEndToEndTestImpl.hpp"
#include "backendsCommon/test/SliceEndToEndTestImpl.hpp"
+#include "backendsCommon/test/SplitterEndToEndTestImpl.hpp"
#include "backendsCommon/test/SubtractionEndToEndTestImpl.hpp"
#include "backendsCommon/test/TransposeConvolution2dEndToEndTestImpl.hpp"
#include "backendsCommon/test/TransposeEndToEndTestImpl.hpp"
@@ -202,6 +203,129 @@
{
SliceEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
}
+
+// Split: Splitter end-to-end tests over tensor rank, split axis and data type
+TEST_CASE("TosaRefSplit1dEndtoEndTestBoolean")
+{
+    Splitter1dEndToEnd<DataType::Boolean>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit1dEndtoEndTestInt8")
+{
+    Splitter1dEndToEnd<DataType::QSymmS8>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit1dEndtoEndTestSigned16")
+{
+    Splitter1dEndToEnd<DataType::QSymmS16>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit1dEndtoEndTestInt32")
+{
+    Splitter1dEndToEnd<DataType::Signed32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit1dEndtoEndTestFloat16")
+{
+    Splitter1dEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit1dEndtoEndTestFloat32")
+{
+    Splitter1dEndToEnd<DataType::Float32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit2dDim0EndtoEndTestFloat32")
+{
+    Splitter2dDim0EndToEnd<DataType::Float32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit2dDim1EndtoEndTestFloat32")
+{
+    Splitter2dDim1EndToEnd<DataType::Float32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim0EndtoEndTestFloat32")
+{
+    Splitter3dDim0EndToEnd<DataType::Float32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim1EndtoEndTestFloat32")
+{
+    Splitter3dDim1EndToEnd<DataType::Float32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim1EndtoEndTestFloat16")
+{
+    Splitter3dDim1EndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim1EndtoEndTestBoolean")
+{
+    Splitter3dDim1EndToEnd<DataType::Boolean>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim1EndtoEndTestInt8")
+{
+    Splitter3dDim1EndToEnd<DataType::QSymmS8>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim1EndtoEndTestSigned16")
+{
+    Splitter3dDim1EndToEnd<DataType::QSymmS16>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim1EndtoEndTestInt32")
+{
+    Splitter3dDim1EndToEnd<DataType::Signed32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit3dDim2EndtoEndTestInt8")
+{
+    Splitter3dDim2EndToEnd<DataType::QAsymmS8>(tosaDefaultBackends); // asymmetric int8, unlike the QSymmS8 "Int8" cases
+}
+
+TEST_CASE("TosaRefSplit4dDim0EndtoEndTestInt8")
+{
+    Splitter4dDim0EndToEnd<DataType::QSymmS8>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim1EndtoEndTestInt8")
+{
+    Splitter4dDim1EndToEnd<DataType::QSymmS8>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim2EndtoEndTestBoolean")
+{
+    Splitter4dDim2EndToEnd<DataType::Boolean>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim2EndtoEndTestInt8")
+{
+    Splitter4dDim2EndToEnd<DataType::QSymmS8>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim2EndtoEndTestSigned16")
+{
+    Splitter4dDim2EndToEnd<DataType::QSymmS16>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim2EndtoEndTestInt32")
+{
+    Splitter4dDim2EndToEnd<DataType::Signed32>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim2EndtoEndTestFloat16")
+{
+    Splitter4dDim2EndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
+}
+
+TEST_CASE("TosaRefSplit4dDim3EndtoEndTestInt8")
+{
+    Splitter4dDim3EndToEnd<DataType::QSymmS8>(tosaDefaultBackends);
+}
+
+// Subtraction
TEST_CASE("TosaRefSubtractionEndtoEndTestFloat32")
{
SubtractionEndToEnd<DataType::Float32>(tosaDefaultBackends);
diff --git a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
index fb4c84f..6f038ab 100644
--- a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
+++ b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
@@ -484,6 +484,54 @@
CHECK(!supported);
}
+TEST_CASE("IsLayerSupportedTosaReferenceSplit")
+{
+    TensorShape inShape = {1, 18, 4, 4};
+    TensorShape outShape = {1, 6, 4, 4}; // dim 1 of the input: 18 / numViews == 6
+    TensorInfo in(inShape, DataType::Float32);
+    TensorInfo out(outShape, DataType::Float32);
+
+    const unsigned int numViews = 3;
+    const unsigned int numDimensions = 4;
+    armnn::SplitterDescriptor descriptor(numViews, numDimensions);
+    descriptor.SetAxis(static_cast<int32_t>(1));
+
+    TosaRefLayerSupport supportChecker;
+    std::string reasonIfNotSupported;
+    auto supported = supportChecker.IsLayerSupported(LayerType::Splitter,
+                                                     {in, out, out, out}, // input, then one output per view
+                                                     descriptor,
+                                                     EmptyOptional(),
+                                                     EmptyOptional(),
+                                                     reasonIfNotSupported);
+
+    CHECK(supported);
+}
+
+TEST_CASE("IsLayerSupportedTosaReferenceSplitUnsupported")
+{
+    TensorShape inShape = {1, 18, 4, 4};
+    TensorShape outShape = {1, 6, 4, 4}; // dim 1 of the input: 18 / numViews == 6
+    TensorInfo in(inShape, DataType::Signed64); // data type this test expects the backend to reject
+    TensorInfo out(outShape, DataType::Signed64);
+
+    const unsigned int numViews = 3;
+    const unsigned int numDimensions = 4;
+    armnn::SplitterDescriptor descriptor(numViews, numDimensions);
+    descriptor.SetAxis(static_cast<int32_t>(1));
+
+    TosaRefLayerSupport supportChecker;
+    std::string reasonIfNotSupported;
+    auto supported = supportChecker.IsLayerSupported(LayerType::Splitter,
+                                                     {in, out, out, out}, // input, then one output per view
+                                                     descriptor,
+                                                     EmptyOptional(),
+                                                     EmptyOptional(),
+                                                     reasonIfNotSupported);
+
+    CHECK(!supported);
+}
+
TEST_CASE("IsLayerSupportedTosaReferenceSubtraction")
{
TensorShape shape0 = {1,1,3,4};
diff --git a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
index 5e4103a..26bd29c 100644
--- a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
+++ b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
diff --git a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.hpp b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.hpp
index 337e8f9..1ea4d8d 100644
--- a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.hpp
+++ b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//