blob: 5bb3095cc7beeb8308d2d7e924ff1c0d44b4f942 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#include <boost/test/unit_test.hpp>
6#include "../OnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include <onnx/onnx.pb.h>
9#include "google/protobuf/stubs/logging.h"
10
11
// Owning pointer to a parsed ONNX model. The tests below store the result of
// OnnxParserImpl::LoadModelFromString in this type and pass it to GetInputs/GetOutputs.
using ModelPtr = std::unique_ptr<onnx::ModelProto>;

// Every test case in this file is registered under the "OnnxParser" Boost.Test suite.
BOOST_AUTO_TEST_SUITE(OnnxParser)
15
16struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
17{
18 explicit GetInputsOutputsMainFixture()
19 {
20 m_Prototext = R"(
21 ir_version: 3
22 producer_name: "CNTK"
23 producer_version: "2.5.1"
24 domain: "ai.cntk"
25 model_version: 1
26 graph {
27 name: "CNTKGraph"
28 input {
29 name: "Input"
30 type {
31 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000032 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010033 shape {
34 dim {
35 dim_value: 4
36 }
37 }
38 }
39 }
40 }
41 node {
42 input: "Input"
43 output: "Output"
44 name: "ActivationLayer"
45 op_type: "Relu"
46 }
47 output {
48 name: "Output"
49 type {
50 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000051 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010052 shape {
53 dim {
54 dim_value: 4
55 }
56 }
57 }
58 }
59 }
60 }
61 opset_import {
62 version: 7
63 })";
64 Setup();
65 }
66};
67
68
69BOOST_FIXTURE_TEST_CASE(GetInput, GetInputsOutputsMainFixture)
70{
Kevin Mayef33cb12021-01-29 14:24:57 +000071 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
72 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
telsoa01c577f2c2018-08-31 09:22:23 +010073 BOOST_CHECK_EQUAL(1, tensors.size());
74 BOOST_CHECK_EQUAL("Input", tensors[0]);
75
76}
77
78BOOST_FIXTURE_TEST_CASE(GetOutput, GetInputsOutputsMainFixture)
79{
Kevin Mayef33cb12021-01-29 14:24:57 +000080 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
81 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model);
telsoa01c577f2c2018-08-31 09:22:23 +010082 BOOST_CHECK_EQUAL(1, tensors.size());
83 BOOST_CHECK_EQUAL("Output", tensors[0]);
84}
85
86struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
87{
88 GetEmptyInputsOutputsFixture()
89 {
90 m_Prototext = R"(
91 ir_version: 3
92 producer_name: "CNTK "
93 producer_version: "2.5.1 "
94 domain: "ai.cntk "
95 model_version: 1
96 graph {
97 name: "CNTKGraph "
98 node {
99 output: "Output"
100 attribute {
101 name: "value"
102 t {
103 dims: 7
Matteo Martincigh44a71672018-12-11 13:46:52 +0000104 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100105 float_data: 0.0
106 float_data: 1.0
107 float_data: 2.0
108 float_data: 3.0
109 float_data: 4.0
110 float_data: 5.0
111 float_data: 6.0
112
113 }
Matteo Martincigh44a71672018-12-11 13:46:52 +0000114 type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100115 }
116 name: "constantNode"
117 op_type: "Constant"
118 }
119 output {
120 name: "Output"
121 type {
122 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000123 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100124 shape {
125 dim {
126 dim_value: 7
127 }
128 }
129 }
130 }
131 }
132 }
133 opset_import {
134 version: 7
135 })";
136 Setup();
137 }
138};
139
140BOOST_FIXTURE_TEST_CASE(GetEmptyInputs, GetEmptyInputsOutputsFixture)
141{
Kevin Mayef33cb12021-01-29 14:24:57 +0000142 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
143 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
telsoa01c577f2c2018-08-31 09:22:23 +0100144 BOOST_CHECK_EQUAL(0, tensors.size());
145}
146
147BOOST_AUTO_TEST_CASE(GetInputsNullModel)
148{
Kevin Mayef33cb12021-01-29 14:24:57 +0000149 BOOST_CHECK_THROW(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException);
telsoa01c577f2c2018-08-31 09:22:23 +0100150}
151
152BOOST_AUTO_TEST_CASE(GetOutputsNullModel)
153{
154 auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf
Kevin Mayef33cb12021-01-29 14:24:57 +0000155 BOOST_CHECK_THROW(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100156}
157
158struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
159{
160 GetInputsMultipleFixture() {
161
162 m_Prototext = R"(
163 ir_version: 3
164 producer_name: "CNTK"
165 producer_version: "2.5.1"
166 domain: "ai.cntk"
167 model_version: 1
168 graph {
169 name: "CNTKGraph"
170 input {
171 name: "Input0"
172 type {
173 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000174 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100175 shape {
176 dim {
177 dim_value: 1
178 }
179 dim {
180 dim_value: 1
181 }
182 dim {
183 dim_value: 1
184 }
185 dim {
186 dim_value: 4
187 }
188 }
189 }
190 }
191 }
192 input {
193 name: "Input1"
194 type {
195 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000196 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100197 shape {
198 dim {
199 dim_value: 4
200 }
201 }
202 }
203 }
204 }
205 node {
206 input: "Input0"
207 input: "Input1"
208 output: "Output"
209 name: "addition"
210 op_type: "Add"
211 doc_string: ""
212 domain: ""
213 }
214 output {
215 name: "Output"
216 type {
217 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000218 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100219 shape {
220 dim {
221 dim_value: 1
222 }
223 dim {
224 dim_value: 1
225 }
226 dim {
227 dim_value: 1
228 }
229 dim {
230 dim_value: 4
231 }
232 }
233 }
234 }
235 }
236 }
237 opset_import {
238 version: 7
239 })";
240 Setup();
241 }
242};
243
244BOOST_FIXTURE_TEST_CASE(GetInputsMultipleInputs, GetInputsMultipleFixture)
245{
Kevin Mayef33cb12021-01-29 14:24:57 +0000246 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
247 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
telsoa01c577f2c2018-08-31 09:22:23 +0100248 BOOST_CHECK_EQUAL(2, tensors.size());
249 BOOST_CHECK_EQUAL("Input0", tensors[0]);
250 BOOST_CHECK_EQUAL("Input1", tensors[1]);
251}
252
253
254
// Closes the "OnnxParser" test suite opened near the top of this file.
BOOST_AUTO_TEST_SUITE_END()