blob: ed7c0bfb08572ff596ba5d9a9c593a628f1c8767 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#include <iostream>
6#include "armnn/ArmNN.hpp"
7
8/// A simple example of using the ArmNN SDK API. In this sample, the users single input number is multiplied by 1.0f
9/// using a fully connected layer with a single neuron to produce an output number that is the same as the input.
10int main()
11{
12 using namespace armnn;
13
14 float number;
15 std::cout << "Please enter a number: " << std::endl;
16 std::cin >> number;
17
18 // Construct ArmNN network
19 armnn::NetworkId networkIdentifier;
20 INetworkPtr myNetwork = INetwork::Create();
21
22 armnn::FullyConnectedDescriptor fullyConnectedDesc;
23 float weightsData[] = {1.0f}; // Identity
24 TensorInfo weightsInfo(TensorShape({1, 1}), DataType::Float32);
25 armnn::ConstTensor weights(weightsInfo, weightsData);
Matteo Martincighfc598e12019-05-14 10:36:13 +010026 IConnectableLayer *fullyConnected = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc,
27 weights,
28 EmptyOptional(),
telsoa01c577f2c2018-08-31 09:22:23 +010029 "fully connected");
30
31 IConnectableLayer *InputLayer = myNetwork->AddInputLayer(0);
32 IConnectableLayer *OutputLayer = myNetwork->AddOutputLayer(0);
33
34 InputLayer->GetOutputSlot(0).Connect(fullyConnected->GetInputSlot(0));
35 fullyConnected->GetOutputSlot(0).Connect(OutputLayer->GetInputSlot(0));
36
37 // Create ArmNN runtime
38 IRuntime::CreationOptions options; // default options
39 IRuntimePtr run = IRuntime::Create(options);
40
41 //Set the tensors in the network.
42 TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
43 InputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
44
45 TensorInfo outputTensorInfo(TensorShape({1, 1}), DataType::Float32);
46 fullyConnected->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
47
48 // Optimise ArmNN network
49 armnn::IOptimizedNetworkPtr optNet = Optimize(*myNetwork, {Compute::CpuRef}, run->GetDeviceSpec());
50
51 // Load graph into runtime
52 run->LoadNetwork(networkIdentifier, std::move(optNet));
53
54 //Creates structures for inputs and outputs.
55 std::vector<float> inputData{number};
56 std::vector<float> outputData(1);
57
58
59 armnn::InputTensors inputTensors{{0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0),
60 inputData.data())}};
61 armnn::OutputTensors outputTensors{{0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0),
62 outputData.data())}};
63
64 // Execute network
65 run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
66
67 std::cout << "Your number was " << outputData[0] << std::endl;
68 return 0;
69
70}