blob: e29aeb1057c43ca20faeb8d487e04cae323b34b0 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
// Groups the BiasAdd parser tests below under the TensorflowParser Boost.Test suite.
BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct BiasAddFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 explicit BiasAddFixture(const std::string& dataFormat)
15 {
16 m_Prototext = R"(
17node {
18 name: "graphInput"
19 op: "Placeholder"
20 attr {
21 key: "dtype"
22 value {
23 type: DT_FLOAT
24 }
25 }
26 attr {
27 key: "shape"
28 value {
29 shape {
30 }
31 }
32 }
33}
34node {
35 name: "bias"
36 op: "Const"
37 attr {
38 key: "dtype"
39 value {
40 type: DT_FLOAT
41 }
42 }
43 attr {
44 key: "value"
45 value {
46 tensor {
47 dtype: DT_FLOAT
48 tensor_shape {
49 dim {
50 size: 3
51 }
52 }
53 float_val: 1
54 float_val: 2
55 float_val: 3
56 }
57 }
58 }
59}
60node {
61 name: "biasAdd"
62 op : "BiasAdd"
63 input: "graphInput"
64 input: "bias"
65 attr {
66 key: "T"
67 value {
68 type: DT_FLOAT
69 }
70 }
71 attr {
72 key: "data_format"
73 value {
74 s: ")" + dataFormat + R"("
75 }
76 }
77}
78)";
79
80 SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd");
81 }
82};
83
84struct BiasAddFixtureNCHW : BiasAddFixture
85{
86 BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {}
87};
88
89struct BiasAddFixtureNHWC : BiasAddFixture
90{
91 BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {}
92};
93
94BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW)
95{
96 RunTest<4>(std::vector<float>(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 });
97}
98
99BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC)
100{
101 RunTest<4>(std::vector<float>(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 });
102}
103
// Closes the TensorflowParser test suite.
BOOST_AUTO_TEST_SUITE_END()