blob: f6dc9a42b265e719f18b5be10698d9c974247b4a [file] [log] [blame]
Colm Donelan0aef6532023-10-02 17:01:37 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include <armnn/INetwork.hpp>
6#include <armnn/IRuntime.hpp>
7#include <armnnOnnxParser/IOnnxParser.hpp>
8#include <iostream>
9
10int main()
11{
12 // Raw protobuf text for a single layer CONV2D model.
13 std::string m_Prototext = R"(
14 ir_version: 3
15 producer_name: "CNTK"
16 producer_version: "2.5.1"
17 domain: "ai.cntk"
18 model_version: 1
19 graph {
20 name: "CNTKGraph"
21 input {
22 name: "Input"
23 type {
24 tensor_type {
25 elem_type: 1
26 shape {
27 dim {
28 dim_value: 1
29 }
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 3
35 }
36 dim {
37 dim_value: 3
38 }
39 }
40 }
41 }
42 }
43 input {
44 name: "Weight"
45 type {
46 tensor_type {
47 elem_type: 1
48 shape {
49 dim {
50 dim_value: 1
51 }
52 dim {
53 dim_value: 1
54 }
55 dim {
56 dim_value: 3
57 }
58 dim {
59 dim_value: 3
60 }
61 }
62 }
63 }
64 }
65 initializer {
66 dims: 1
67 dims: 1
68 dims: 3
69 dims: 3
70 data_type: 1
71 float_data: 2
72 float_data: 1
73 float_data: 0
74 float_data: 6
75 float_data: 2
76 float_data: 1
77 float_data: 4
78 float_data: 1
79 float_data: 2
80 name: "Weight"
81 }
82 node {
83 input: "Input"
84 input: "Weight"
85 output: "Output"
86 name: "Convolution"
87 op_type: "Conv"
88 attribute {
89 name: "kernel_shape"
90 ints: 3
91 ints: 3
92 type: INTS
93 }
94 attribute {
95 name: "strides"
96 ints: 1
97 ints: 1
98 type: INTS
99 }
100 attribute {
101 name: "auto_pad"
102 s: "VALID"
103 type: STRING
104 }
105 attribute {
106 name: "group"
107 i: 1
108 type: INT
109 }
110 attribute {
111 name: "dilations"
112 ints: 1
113 ints: 1
114 type: INTS
115 }
116 doc_string: ""
117 domain: ""
118 }
119 output {
120 name: "Output"
121 type {
122 tensor_type {
123 elem_type: 1
124 shape {
125 dim {
126 dim_value: 1
127 }
128 dim {
129 dim_value: 1
130 }
131 dim {
132 dim_value: 1
133 }
134 dim {
135 dim_value: 1
136 }
137 }
138 }
139 }
140 }
141 }
142 opset_import {
143 version: 7
144 })";
145
146 using namespace armnn;
147
148 // Create ArmNN runtime
149 IRuntime::CreationOptions options; // default options
150 IRuntimePtr runtime = IRuntime::Create(options);
151 // Create the parser.
152 armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
153 try
154 {
155 // Parse the proto text.
156 armnn::INetworkPtr network = parser->CreateNetworkFromString(m_Prototext);
157 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, runtime->GetDeviceSpec());
158 if (!optimized)
159 {
160 std::cout << "Error: Failed to optimise the input network." << std::endl;
161 return 1;
162 }
163 armnn::NetworkId networkId;
164 std::string errorMsg;
165 Status status = runtime->LoadNetwork(networkId, std::move(optimized), errorMsg);
166 if (status != Status::Success)
167 {
168 std::cout << "Error: Failed to load the optimized network." << std::endl;
169 return -1;
170 }
171
172 // Setup the input and output.
173 std::vector<armnnOnnxParser::BindingPointInfo> inputBindings;
174 // Coz we know the model we know the input tensor is called Input and output is Output.
175 inputBindings.push_back(parser->GetNetworkInputBindingInfo("Input"));
176 std::vector<armnnOnnxParser::BindingPointInfo> outputBindings;
177 outputBindings.push_back(parser->GetNetworkOutputBindingInfo("Output"));
178 // Allocate input tensors
179 armnn::InputTensors inputTensors;
180 std::vector<float> in_data(inputBindings[0].second.GetNumElements());
181 TensorInfo inputTensorInfo(inputBindings[0].second);
182 inputTensorInfo.SetConstant(true);
183 // Set some kind of values in the input.
184 for (int i = 0; i < inputBindings[0].second.GetNumElements(); i++)
185 {
186 in_data[i] = 1.0f + i;
187 }
188 inputTensors.push_back({ inputBindings[0].first, armnn::ConstTensor(inputTensorInfo, in_data.data()) });
189
190 // Allocate output tensors
191 armnn::OutputTensors outputTensors;
192 std::vector<float> out_data(outputBindings[0].second.GetNumElements());
193 outputTensors.push_back({ outputBindings[0].first, armnn::Tensor(outputBindings[0].second, out_data.data()) });
194
195 runtime->EnqueueWorkload(networkId, inputTensors, outputTensors);
196 runtime->UnloadNetwork(networkId);
197 // We're finished with the parser.
198 armnnOnnxParser::IOnnxParser::Destroy(parser.get());
199 parser.release();
200 }
201 catch (const std::exception& e) // Could be an InvalidArgumentException or a ParseException.
202 {
203 std::cout << "Unable to create parser for the passed protobuf string. Reason: " << e.what() << std::endl;
204 return -1;
205 }
206 return 0;
207}