blob: f61d56ceff96a3535a0d6bf27e9a20efd7cf6d10 [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#pragma once
7
#include "RequestThread.hpp"

#include "HalInterfaces.h"
#include "NeuralNetworks.h"
#include <armnn/ArmNN.hpp>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>
16
17namespace armnn_driver
18{
19
20class ArmnnPreparedModel : public IPreparedModel
21{
22public:
23 ArmnnPreparedModel(armnn::NetworkId networkId,
24 armnn::IRuntime* runtime,
25 const Model& model,
26 const std::string& requestInputsAndOutputsDumpDir);
27
28 virtual ~ArmnnPreparedModel();
29
30 virtual Return<ErrorStatus> execute(const Request& request,
31 const ::android::sp<IExecutionCallback>& callback) override;
32
33 /// execute the graph prepared from the request
34 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
35 std::shared_ptr<armnn::InputTensors>& pInputTensors,
36 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
37 const ::android::sp<IExecutionCallback>& callback);
38
39 /// Executes this model with dummy inputs (e.g. all zeroes).
40 void ExecuteWithDummyInputs();
41
42private:
43
44 template <typename TensorBindingCollection>
45 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
46
47 armnn::NetworkId m_NetworkId;
48 armnn::IRuntime* m_Runtime;
49 Model m_Model;
50 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
51 // It is specific to this class, so it is declared as static here
52 static RequestThread m_RequestThread;
53 uint32_t m_RequestCount;
54 const std::string& m_RequestInputsAndOutputsDumpDir;
55};
56
57class AndroidNnCpuExecutorPreparedModel : public IPreparedModel
58{
59public:
60
61 AndroidNnCpuExecutorPreparedModel(const Model& model, const std::string& requestInputsAndOutputsDumpDir);
62 virtual ~AndroidNnCpuExecutorPreparedModel() { }
63
64 bool Initialize();
65
66 virtual Return<ErrorStatus> execute(const Request& request,
67 const ::android::sp<IExecutionCallback>& callback) override;
68
69private:
70
71 void DumpTensorsIfRequired(
72 char const* tensorNamePrefix,
73 const hidl_vec<uint32_t>& operandIndices,
74 const hidl_vec<RequestArgument>& requestArgs,
75 const std::vector<android::nn::RunTimePoolInfo>& requestPoolInfos);
76
77 Model m_Model;
78 std::vector<android::nn::RunTimePoolInfo> m_ModelPoolInfos;
79 const std::string& m_RequestInputsAndOutputsDumpDir;
80 uint32_t m_RequestCount;
81};
82
83}