//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <boost/test/unit_test.hpp>
#include "armnnCaffeParser/ICaffeParser.hpp"
#include "armnn/IRuntime.hpp"
#include "armnn/INetwork.hpp"
#include "armnn/Exceptions.hpp"

#include "test/TensorHelpers.hpp"

#include <map>
#include <string>
#include <vector>

#include "ParserPrototxtFixture.hpp"

17BOOST_AUTO_TEST_SUITE(CaffeParser)
18
19
20BOOST_AUTO_TEST_CASE(InputShapes)
21{
22 std::string explicitInput = "name: \"Minimal\"\n"
23 "layer {\n"
24 " name: \"data\"\n"
25 " type: \"Input\"\n"
26 " top: \"data\"\n"
27 " input_param { shape: { dim: 1 dim: 2 dim: 3 dim: 4 } }\n"
28 "}";
29 std::string implicitInput = "name: \"Minimal\"\n"
30 "input: \"data\" \n"
31 "input_dim: 1 \n"
32 "input_dim: 2 \n"
33 "input_dim: 3 \n"
34 "input_dim: 4 \n";
35 std::string implicitInputNoShape = "name: \"Minimal\"\n"
36 "input: \"data\" \n";
37
telsoa01c577f2c2018-08-31 09:22:23 +010038 armnn::IRuntime::CreationOptions options;
39 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
telsoa014fcda012018-03-09 14:13:49 +000040 armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
41 armnn::INetworkPtr network(nullptr, nullptr);
42 armnn::NetworkId netId;
43
44 // Check everything works normally
David Beckf0b48452018-10-19 15:20:56 +010045 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
telsoa014fcda012018-03-09 14:13:49 +000046 {
47 network = parser->CreateNetworkFromString(explicitInput.c_str(), {}, { "data" });
48 BOOST_TEST(network.get());
telsoa01c577f2c2018-08-31 09:22:23 +010049 runtime->LoadNetwork(netId, Optimize(*network, backends, runtime->GetDeviceSpec()));
telsoa014fcda012018-03-09 14:13:49 +000050
51 armnnCaffeParser::BindingPointInfo inputBindingInfo = parser->GetNetworkInputBindingInfo("data");
52 armnn::TensorInfo inputTensorInfo = inputBindingInfo.second;
53 BOOST_TEST((inputTensorInfo == runtime->GetInputTensorInfo(netId, inputBindingInfo.first)));
54
55 BOOST_TEST(inputTensorInfo.GetShape()[0] == 1);
56 BOOST_TEST(inputTensorInfo.GetShape()[1] == 2);
57 BOOST_TEST(inputTensorInfo.GetShape()[2] == 3);
58 BOOST_TEST(inputTensorInfo.GetShape()[3] == 4);
59 }
60
telsoa01c577f2c2018-08-31 09:22:23 +010061 // Checks everything works with implicit input.
telsoa014fcda012018-03-09 14:13:49 +000062 {
63 network = parser->CreateNetworkFromString(implicitInput.c_str(), {}, { "data" });
64 BOOST_TEST(network.get());
telsoa01c577f2c2018-08-31 09:22:23 +010065 runtime->LoadNetwork(netId, Optimize(*network, backends, runtime->GetDeviceSpec()));
telsoa014fcda012018-03-09 14:13:49 +000066
67 armnnCaffeParser::BindingPointInfo inputBindingInfo = parser->GetNetworkInputBindingInfo("data");
68 armnn::TensorInfo inputTensorInfo = inputBindingInfo.second;
69 BOOST_TEST((inputTensorInfo == runtime->GetInputTensorInfo(netId, inputBindingInfo.first)));
70
71 BOOST_TEST(inputTensorInfo.GetShape()[0] == 1);
72 BOOST_TEST(inputTensorInfo.GetShape()[1] == 2);
73 BOOST_TEST(inputTensorInfo.GetShape()[2] == 3);
74 BOOST_TEST(inputTensorInfo.GetShape()[3] == 4);
75 }
76
telsoa01c577f2c2018-08-31 09:22:23 +010077 // Checks everything works with implicit and passing shape.
telsoa014fcda012018-03-09 14:13:49 +000078 {
79 network = parser->CreateNetworkFromString(implicitInput.c_str(), { {"data", { 2, 2, 3, 4 } } }, { "data" });
80 BOOST_TEST(network.get());
telsoa01c577f2c2018-08-31 09:22:23 +010081 runtime->LoadNetwork(netId, Optimize(*network, backends, runtime->GetDeviceSpec()));
telsoa014fcda012018-03-09 14:13:49 +000082
83 armnnCaffeParser::BindingPointInfo inputBindingInfo = parser->GetNetworkInputBindingInfo("data");
84 armnn::TensorInfo inputTensorInfo = inputBindingInfo.second;
85 BOOST_TEST((inputTensorInfo == runtime->GetInputTensorInfo(netId, inputBindingInfo.first)));
86
87 BOOST_TEST(inputTensorInfo.GetShape()[0] == 2);
88 BOOST_TEST(inputTensorInfo.GetShape()[1] == 2);
89 BOOST_TEST(inputTensorInfo.GetShape()[2] == 3);
90 BOOST_TEST(inputTensorInfo.GetShape()[3] == 4);
91 }
92
telsoa01c577f2c2018-08-31 09:22:23 +010093 // Checks everything works with implicit (no shape) and passing shape.
telsoa014fcda012018-03-09 14:13:49 +000094 {
95 network = parser->CreateNetworkFromString(implicitInputNoShape.c_str(), {{"data", {2, 2, 3, 4} }}, { "data" });
96 BOOST_TEST(network.get());
telsoa01c577f2c2018-08-31 09:22:23 +010097 runtime->LoadNetwork(netId, Optimize(*network, backends, runtime->GetDeviceSpec()));
telsoa014fcda012018-03-09 14:13:49 +000098
99 armnnCaffeParser::BindingPointInfo inputBindingInfo = parser->GetNetworkInputBindingInfo("data");
100 armnn::TensorInfo inputTensorInfo = inputBindingInfo.second;
101 BOOST_TEST((inputTensorInfo == runtime->GetInputTensorInfo(netId, inputBindingInfo.first)));
102
103 BOOST_TEST(inputTensorInfo.GetShape()[0] == 2);
104 BOOST_TEST(inputTensorInfo.GetShape()[1] == 2);
105 BOOST_TEST(inputTensorInfo.GetShape()[2] == 3);
106 BOOST_TEST(inputTensorInfo.GetShape()[3] == 4);
107 }
108
telsoa01c577f2c2018-08-31 09:22:23 +0100109 // Checks exception on incompatible shapes.
telsoa014fcda012018-03-09 14:13:49 +0000110 {
111 BOOST_CHECK_THROW(parser->CreateNetworkFromString(implicitInput.c_str(), {{"data",{ 2, 2, 3, 2 }}}, {"data"}),
112 armnn::ParseException);
113 }
114
telsoa01c577f2c2018-08-31 09:22:23 +0100115 // Checks exception when no shape available.
telsoa014fcda012018-03-09 14:13:49 +0000116 {
117 BOOST_CHECK_THROW(parser->CreateNetworkFromString(implicitInputNoShape.c_str(), {}, { "data" }),
118 armnn::ParseException);
119 }
120}
121
122BOOST_AUTO_TEST_SUITE_END()