blob: ac06cba4104c160f9ae29749eef2c2a6ea5c8bab [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
Sadik Armagan1625efc2021-06-10 18:24:34 +01005
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "../OnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include <onnx/onnx.pb.h>
9#include "google/protobuf/stubs/logging.h"
10
// Owning handle to an ONNX model proto; the tests below receive one from
// armnnOnnxParser::OnnxParserImpl::LoadModelFromString and pass it to GetInputs/GetOutputs.
using ModelPtr = std::unique_ptr<onnx::ModelProto>;
12
Sadik Armagan1625efc2021-06-10 18:24:34 +010013TEST_SUITE("OnnxParser_GetInputsOutputs")
14{
telsoa01c577f2c2018-08-31 09:22:23 +010015struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
16{
17 explicit GetInputsOutputsMainFixture()
18 {
19 m_Prototext = R"(
20 ir_version: 3
21 producer_name: "CNTK"
22 producer_version: "2.5.1"
23 domain: "ai.cntk"
24 model_version: 1
25 graph {
26 name: "CNTKGraph"
27 input {
28 name: "Input"
29 type {
30 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000031 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010032 shape {
33 dim {
34 dim_value: 4
35 }
36 }
37 }
38 }
39 }
40 node {
41 input: "Input"
42 output: "Output"
43 name: "ActivationLayer"
44 op_type: "Relu"
45 }
46 output {
47 name: "Output"
48 type {
49 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000050 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010051 shape {
52 dim {
53 dim_value: 4
54 }
55 }
56 }
57 }
58 }
59 }
60 opset_import {
61 version: 7
62 })";
63 Setup();
64 }
65};
66
67
Sadik Armagan1625efc2021-06-10 18:24:34 +010068TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetInput")
telsoa01c577f2c2018-08-31 09:22:23 +010069{
Kevin Mayef33cb12021-01-29 14:24:57 +000070 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
71 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
Sadik Armagan1625efc2021-06-10 18:24:34 +010072 CHECK_EQ(1, tensors.size());
73 CHECK_EQ("Input", tensors[0]);
telsoa01c577f2c2018-08-31 09:22:23 +010074
75}
76
Sadik Armagan1625efc2021-06-10 18:24:34 +010077TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetOutput")
telsoa01c577f2c2018-08-31 09:22:23 +010078{
Kevin Mayef33cb12021-01-29 14:24:57 +000079 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
80 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model);
Sadik Armagan1625efc2021-06-10 18:24:34 +010081 CHECK_EQ(1, tensors.size());
82 CHECK_EQ("Output", tensors[0]);
telsoa01c577f2c2018-08-31 09:22:23 +010083}
84
85struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
86{
87 GetEmptyInputsOutputsFixture()
88 {
89 m_Prototext = R"(
90 ir_version: 3
91 producer_name: "CNTK "
92 producer_version: "2.5.1 "
93 domain: "ai.cntk "
94 model_version: 1
95 graph {
96 name: "CNTKGraph "
97 node {
98 output: "Output"
99 attribute {
100 name: "value"
101 t {
102 dims: 7
Matteo Martincigh44a71672018-12-11 13:46:52 +0000103 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100104 float_data: 0.0
105 float_data: 1.0
106 float_data: 2.0
107 float_data: 3.0
108 float_data: 4.0
109 float_data: 5.0
110 float_data: 6.0
111
112 }
Matteo Martincigh44a71672018-12-11 13:46:52 +0000113 type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100114 }
115 name: "constantNode"
116 op_type: "Constant"
117 }
118 output {
119 name: "Output"
120 type {
121 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000122 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100123 shape {
124 dim {
125 dim_value: 7
126 }
127 }
128 }
129 }
130 }
131 }
132 opset_import {
133 version: 7
134 })";
135 Setup();
136 }
137};
138
Sadik Armagan1625efc2021-06-10 18:24:34 +0100139TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs")
telsoa01c577f2c2018-08-31 09:22:23 +0100140{
Kevin Mayef33cb12021-01-29 14:24:57 +0000141 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
142 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100143 CHECK_EQ(0, tensors.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100144}
145
Sadik Armagan1625efc2021-06-10 18:24:34 +0100146TEST_CASE("GetInputsNullModel")
telsoa01c577f2c2018-08-31 09:22:23 +0100147{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100148 CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException);
telsoa01c577f2c2018-08-31 09:22:23 +0100149}
150
Sadik Armagan1625efc2021-06-10 18:24:34 +0100151TEST_CASE("GetOutputsNullModel")
telsoa01c577f2c2018-08-31 09:22:23 +0100152{
153 auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf
Sadik Armagan1625efc2021-06-10 18:24:34 +0100154 CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100155}
156
157struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
158{
159 GetInputsMultipleFixture() {
160
161 m_Prototext = R"(
162 ir_version: 3
163 producer_name: "CNTK"
164 producer_version: "2.5.1"
165 domain: "ai.cntk"
166 model_version: 1
167 graph {
168 name: "CNTKGraph"
169 input {
170 name: "Input0"
171 type {
172 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000173 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100174 shape {
175 dim {
176 dim_value: 1
177 }
178 dim {
179 dim_value: 1
180 }
181 dim {
182 dim_value: 1
183 }
184 dim {
185 dim_value: 4
186 }
187 }
188 }
189 }
190 }
191 input {
192 name: "Input1"
193 type {
194 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000195 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100196 shape {
197 dim {
198 dim_value: 4
199 }
200 }
201 }
202 }
203 }
204 node {
205 input: "Input0"
206 input: "Input1"
207 output: "Output"
208 name: "addition"
209 op_type: "Add"
210 doc_string: ""
211 domain: ""
212 }
213 output {
214 name: "Output"
215 type {
216 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000217 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100218 shape {
219 dim {
220 dim_value: 1
221 }
222 dim {
223 dim_value: 1
224 }
225 dim {
226 dim_value: 1
227 }
228 dim {
229 dim_value: 4
230 }
231 }
232 }
233 }
234 }
235 }
236 opset_import {
237 version: 7
238 })";
239 Setup();
240 }
241};
242
Sadik Armagan1625efc2021-06-10 18:24:34 +0100243TEST_CASE_FIXTURE(GetInputsMultipleFixture, "GetInputsMultipleInputs")
telsoa01c577f2c2018-08-31 09:22:23 +0100244{
Kevin Mayef33cb12021-01-29 14:24:57 +0000245 ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
246 std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100247 CHECK_EQ(2, tensors.size());
248 CHECK_EQ("Input0", tensors[0]);
249 CHECK_EQ("Input1", tensors[1]);
telsoa01c577f2c2018-08-31 09:22:23 +0100250}
251
Sadik Armagan1625efc2021-06-10 18:24:34 +0100252}