blob: 2dce822b0fef67969941be9063791731c475cc49 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct EqualFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14 EqualFixture()
15 {
16 m_Prototext = R"(
17node {
18 name: "input0"
19 op: "Placeholder"
20 attr {
21 key: "dtype"
22 value {
23 type: DT_FLOAT
24 }
25 }
26 attr {
27 key: "shape"
28 value {
29 shape {
30 }
31 }
32 }
33}
34node {
35 name: "input1"
36 op: "Placeholder"
37 attr {
38 key: "dtype"
39 value {
40 type: DT_FLOAT
41 }
42 }
43 attr {
44 key: "shape"
45 value {
46 shape {
47 }
48 }
49 }
50}
51node {
52 name: "output"
53 op: "Equal"
54 input: "input0"
55 input: "input1"
56 attr {
57 key: "T"
58 value {
59 type: DT_FLOAT
60 }
61 }
62}
63 )";
64 }
65 };
66
67BOOST_FIXTURE_TEST_CASE(ParseEqualUnsupportedBroadcast, EqualFixture)
68{
69 BOOST_REQUIRE_THROW(Setup({ { "input0", {2, 3} },
70 { "input1", {1, 2, 2, 3} } },
71 { "output" }),
72 armnn::ParseException);
73}
74
75struct EqualFixtureAutoSetup : public EqualFixture
76{
77 EqualFixtureAutoSetup(const armnn::TensorShape& input0Shape,
78 const armnn::TensorShape& input1Shape)
79 : EqualFixture()
80 {
81 Setup({ { "input0", input0Shape },
82 { "input1", input1Shape } },
83 { "output" });
84 }
85};
86
87struct EqualTwoByTwo : public EqualFixtureAutoSetup
88{
89 EqualTwoByTwo() : EqualFixtureAutoSetup({2,2}, {2,2}) {}
90};
91
92BOOST_FIXTURE_TEST_CASE(ParseEqualTwoByTwo, EqualTwoByTwo)
93{
kevmay012b4d88e2019-01-24 14:05:09 +000094 RunComparisonTest<2>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
95 { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } } },
96 { { "output", { 1, 0, 0, 1 } } });
jimfly0184c70e62018-12-19 13:14:46 +000097}
98
99struct EqualBroadcast1DAnd4D : public EqualFixtureAutoSetup
100{
101 EqualBroadcast1DAnd4D() : EqualFixtureAutoSetup({1}, {1,1,2,2}) {}
102};
103
104BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast1DToTwoByTwo, EqualBroadcast1DAnd4D)
105{
kevmay012b4d88e2019-01-24 14:05:09 +0000106 RunComparisonTest<4>({ { "input0", { 2.0f } },
107 { "input1", { 1.0f, 2.0f, 3.0f, 2.0f } } },
108 { { "output", { 0, 1, 0, 1 } } });
jimfly0184c70e62018-12-19 13:14:46 +0000109}
110
111struct EqualBroadcast4DAnd1D : public EqualFixtureAutoSetup
112{
113 EqualBroadcast4DAnd1D() : EqualFixtureAutoSetup({1,1,2,2}, {1}) {}
114};
115
116BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast4DAnd1D, EqualBroadcast4DAnd1D)
117{
kevmay012b4d88e2019-01-24 14:05:09 +0000118 RunComparisonTest<4>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
119 { "input1", { 3.0f } } },
120 { { "output", { 0, 0, 1, 0 } } });
jimfly0184c70e62018-12-19 13:14:46 +0000121}
122
123struct EqualMultiDimBroadcast : public EqualFixtureAutoSetup
124{
125 EqualMultiDimBroadcast() : EqualFixtureAutoSetup({1,1,2,1}, {1,2,1,3}) {}
126};
127
128BOOST_FIXTURE_TEST_CASE(ParseEqualMultiDimBroadcast, EqualMultiDimBroadcast)
129{
kevmay012b4d88e2019-01-24 14:05:09 +0000130 RunComparisonTest<4>({ { "input0", { 1.0f, 2.0f } },
131 { "input1", { 1.0f, 2.0f, 3.0f,
132 3.0f, 2.0f, 2.0f } } },
133 { { "output", { 1, 0, 0,
134 0, 1, 0,
135 0, 0, 0,
136 0, 1, 1 } } });
jimfly0184c70e62018-12-19 13:14:46 +0000137}
138
139BOOST_AUTO_TEST_SUITE_END()