//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6
7#include <armnn/Optional.hpp>
8#include <armnn/Types.hpp>
9#include <tosaReference/TosaRefLayerSupport.hpp>
10
11#include <doctest/doctest.h>
12
13#include <string>
14
15TEST_SUITE("TosaRefLayerSupported")
16{
17
18TEST_CASE("IsLayerSupportedTosaReferenceAddition")
19{
20 armnn::TensorShape shape0 = {1,1,3,4};
21 armnn::TensorShape shape1 = {4};
22 armnn::TensorShape outShape = {1,1,3,4};
23 armnn::TensorInfo in0(shape0, armnn::DataType::Float32);
24 armnn::TensorInfo in1(shape1, armnn::DataType::Float32);
25 armnn::TensorInfo out(outShape, armnn::DataType::Float32);
26
27 armnn::BaseDescriptor desc;
28 armnn::TosaRefLayerSupport supportChecker;
29 std::string reasonIfNotSupported;
30 auto supported = supportChecker.IsLayerSupported(armnn::LayerType::Addition,
31 {in0, in1, out},
32 desc,
33 armnn::EmptyOptional(),
34 armnn::EmptyOptional(),
35 reasonIfNotSupported);
36
37 CHECK(supported);
38}
39
40TEST_CASE("IsLayerSupportedTosaReferenceAdditionUnsupported")
41{
42 armnn::TensorShape shape0 = {1,1,3,4};
43 armnn::TensorShape shape1 = {4};
44 armnn::TensorShape outShape = {1,1,3,4};
45 armnn::TensorInfo in0(shape0, armnn::DataType::Signed64);
46 armnn::TensorInfo in1(shape1, armnn::DataType::Signed64);
47 armnn::TensorInfo out(outShape, armnn::DataType::Signed64);
48
49 armnn::BaseDescriptor desc;
50 armnn::TosaRefLayerSupport supportChecker;
51 std::string reasonIfNotSupported;
52 auto supported = supportChecker.IsLayerSupported(armnn::LayerType::Addition,
53 {in0, in1, out},
54 desc,
55 armnn::EmptyOptional(),
56 armnn::EmptyOptional(),
57 reasonIfNotSupported);
58
59 CHECK(!supported);
60 REQUIRE(reasonIfNotSupported.find("TOSA Reference addition: Op_ADD_input0_") != std::string::npos);
61 REQUIRE(reasonIfNotSupported.find("TOSA Reference addition: Op_ADD_input1_") != std::string::npos);
62 REQUIRE(reasonIfNotSupported.find("TOSA Reference addition: Op_ADD_output0_") != std::string::npos);
63}
64
65}