blob: 1c74741e09cc990c03a990a054c48878dd739193 [file] [log] [blame]
Colm Donelan0aef6532023-10-02 17:01:37 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnnTfLiteParser/ITfLiteParser.hpp>

#include <cstdlib>
#include <exception>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
10
11int main()
12{
13 using namespace armnn;
14
15 // Create ArmNN runtime
16 IRuntime::CreationOptions options; // default options
17 IRuntimePtr runtime = IRuntime::Create(options);
18 // Parse a TfLite file.
19 armnnTfLiteParser::ITfLiteParserPtr parser = armnnTfLiteParser::ITfLiteParser::Create();
20 try
21 {
22 INetworkPtr myNetwork = parser->CreateNetworkFromBinaryFile("./simple_conv2d_1_op.tflite");
23 // Optimise ArmNN network
24 IOptimizedNetworkPtr optNet = Optimize(*myNetwork, { Compute::CpuRef }, runtime->GetDeviceSpec());
25 if (!optNet)
26 {
27 std::cout << "Error: Failed to optimise the input network." << std::endl;
28 return 1;
29 }
30 NetworkId networkId;
31 // Load graph into runtime
32 Status loaded = runtime->LoadNetwork(networkId, std::move(optNet));
33 if (loaded != Status::Success)
34 {
35 std::cout << "Error: Failed to load the optimized network." << std::endl;
36 return 1;
37 }
38
39 // Setup the input and output.
40 std::vector<armnnTfLiteParser::BindingPointInfo> inputBindings;
41 std::vector<std::string> inputTensorNames = parser->GetSubgraphInputTensorNames(0);
42 inputBindings.push_back(parser->GetNetworkInputBindingInfo(0, inputTensorNames[0]));
43
44 std::vector<armnnTfLiteParser::BindingPointInfo> outputBindings;
45 std::vector<std::string> outputTensorNames = parser->GetSubgraphOutputTensorNames(0);
46 outputBindings.push_back(parser->GetNetworkOutputBindingInfo(0, outputTensorNames[0]));
47 TensorInfo inputTensorInfo(inputBindings[0].second);
48 inputTensorInfo.SetConstant(true);
49
50 // Allocate input tensors
51 armnn::InputTensors inputTensors;
52 std::vector<float> in_data(inputBindings[0].second.GetNumElements());
53 // Set some kind of values in the input.
54 for (int i = 0; i < inputBindings[0].second.GetNumElements(); i++)
55 {
56 in_data[i] = 1.0f + i;
57 }
58 inputTensors.push_back({ inputBindings[0].first, armnn::ConstTensor(inputTensorInfo, in_data.data()) });
59
60 // Allocate output tensors
61 armnn::OutputTensors outputTensors;
62 std::vector<float> out_data(outputBindings[0].second.GetNumElements());
63 outputTensors.push_back({ outputBindings[0].first, armnn::Tensor(outputBindings[0].second, out_data.data()) });
64
65 runtime->EnqueueWorkload(networkId, inputTensors, outputTensors);
66 runtime->UnloadNetwork(networkId);
67 // We're finished with the parser.
68 armnnTfLiteParser::ITfLiteParser::Destroy(parser.get());
69 parser.release();
70 }
71 catch (const std::exception& e) // Could be: InvalidArgumentException, ParseException or a FileNotFoundException.
72 {
73 std::cout << "Unable to create parser for \"./simple_conv2d_1_op.tflite\". Reason: " << e.what() << std::endl;
74 return -1;
75 }
76
77 return 0;
78}