blob: 2b3cbe65d8c3e2f46be3a88acb1ca779207e689e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct SubFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 SubFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1)
15 {
16 m_Prototext = R"(
17node {
18 name: "input0"
19 op: "Placeholder"
20 attr {
21 key: "dtype"
22 value {
23 type: DT_FLOAT
24 }
25 }
26 attr {
27 key: "shape"
28 value {
29 shape {
30 }
31 }
32 }
33}
34node {
35 name: "input1"
36 op: "Placeholder"
37 attr {
38 key: "dtype"
39 value {
40 type: DT_FLOAT
41 }
42 }
43 attr {
44 key: "shape"
45 value {
46 shape {
47 }
48 }
49 }
50}
51node {
52 name: "output"
53 op: "Sub"
54 input: "input0"
55 input: "input1"
56 attr {
57 key: "T"
58 value {
59 type: DT_FLOAT
60 }
61 }
62}
63 )";
64 Setup({ { "input0", inputShape0 },
65 { "input1", inputShape1 } },
66 { "output" });
67
68 }
69};
70
71struct SubFixture4D4D : public SubFixture
72{
73 SubFixture4D4D() : SubFixture({ 1, 2, 2, 3 }, { 1, 2, 2, 3 }) {}
74};
75
76BOOST_FIXTURE_TEST_CASE(ParseSub, SubFixture4D4D)
77{
78 RunTest<4>({ { "input0", { 5.0f, 1.0f, 2.0f,
79 3.0f, 4.0f, 5.0f,
80 6.0f, 7.0f, 8.0f,
81 29.0f, 10.0f, 11.0f } },
82
83 { "input1", { 0.0f, 1.0f, 3.0f,
84 4.0f, 5.5f, 1.0f,
85 2.0f, 17.0f, 18.0f,
86 19.0f, 1.0f, 3.0f } } },
87
88 { { "output", { 5.0f, 0.0f, -1.0f,
89 -1.0f, -1.5f, 4.0f,
90 4.0f, -10.0f, -10.0f,
91 10.0f, 9.0f, 8.0f } } });
92}
93
94struct SubBroadcastFixture4D1D : public SubFixture
95{
96 SubBroadcastFixture4D1D() : SubFixture({ 1, 2, 2, 3 }, { 1 }) {}
97};
98
99BOOST_FIXTURE_TEST_CASE(ParseSubBroadcast4D1D, SubBroadcastFixture4D1D)
100{
101 RunTest<4>({ { "input0", { 0.0f, 1.0f, 2.0f,
102 3.0f, 4.0f, 5.0f,
103 6.0f, 7.0f, 8.0f,
104 9.0f, 10.0f, 11.0f } },
105
106 { "input1", { 5.0f } } },
107
108 { { "output", { -5.0f, -4.0f, -3.0f,
109 -2.0f, -1.0f, 0.0f,
110 1.0f, 2.0f, 3.0f,
111 4.0f, 5.0f, 6.0f } } });
112}
113
114struct SubBroadcastFixture1D4D : public SubFixture
115{
116 SubBroadcastFixture1D4D() : SubFixture({ 1 }, { 1, 2, 2, 3 }) {}
117};
118
119BOOST_FIXTURE_TEST_CASE(ParseSubBroadcast1D4D, SubBroadcastFixture1D4D)
120{
121 RunTest<4>({ { "input0", { 3.0f } },
122
123 { "input1", { 0.0f, 1.0f, 2.0f,
124 3.0f, 4.0f, 5.0f,
125 6.0f, 7.0f, 8.0f,
126 9.0f, 10.0f, 11.0f } } },
127
128 { { "output", { 3.0f, 2.0f, 1.0f,
129 0.0f, -1.0f, -2.0f,
130 -3.0f, -4.0f, -5.0f,
131 -6.0f, -7.0f, -8.0f } } });
132}
133
134
135BOOST_AUTO_TEST_SUITE_END()