blob: 736e13c1ada00d73d163c3be0bbbbb80ab8db4b7 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include <boost/test/unit_test.hpp>
6#include "armnnTfParser/ITfParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
9BOOST_AUTO_TEST_SUITE(TensorflowParser)
10
telsoa01c577f2c2018-08-31 09:22:23 +010011struct PassThruFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010012{
13 PassThruFixture()
14 {
15 m_Prototext = "node {\n"
16 " name: \"Placeholder\"\n"
17 " op: \"Placeholder\"\n"
18 " attr {\n"
19 " key: \"dtype\"\n"
20 " value {\n"
21 " type: DT_FLOAT\n"
22 " }\n"
23 " }\n"
24 " attr {\n"
25 " key: \"shape\"\n"
26 " value {\n"
27 " shape {\n"
28 " }\n"
29 " }\n"
30 " }\n"
31 "}\n";
32 SetupSingleInputSingleOutput({ 1, 7 }, "Placeholder", "Placeholder");
33 }
34};
35
36BOOST_FIXTURE_TEST_CASE(ValidateOutput, PassThruFixture)
37{
38 BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetNumDimensions() == 2);
39 BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetShape()[0] == 1);
40 BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetShape()[1] == 7);
41}
42
43BOOST_FIXTURE_TEST_CASE(RunGraph, PassThruFixture)
44{
45 armnn::TensorInfo inputTensorInfo = m_Parser->GetNetworkInputBindingInfo("Placeholder").second;
46 auto input = MakeRandomTensor<float, 2>(inputTensorInfo, 378346);
47 std::vector<float> inputVec;
48 inputVec.assign(input.data(), input.data() + input.num_elements());
telsoa01c577f2c2018-08-31 09:22:23 +010049 RunTest<2>(inputVec, inputVec); // The passthru network should output the same as the input.
surmeh01bceff2f2018-03-29 16:29:27 +010050}
51
52BOOST_AUTO_TEST_SUITE_END()