blob: f7644b95426f5d8d0b448272479f91dd2d90aad7 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#pragma once
7
8#include "RequestThread.hpp"
9
10#include "HalInterfaces.h"
11#include "NeuralNetworks.h"
12#include <armnn/ArmNN.hpp>
13
surmeh01deb3bdb2018-07-05 12:06:04 +010014#include "ArmnnDriver.hpp"
15
telsoa015307bc12018-03-09 13:51:08 +000016#include <string>
17#include <vector>
18
19namespace armnn_driver
20{
21
22class ArmnnPreparedModel : public IPreparedModel
23{
24public:
25 ArmnnPreparedModel(armnn::NetworkId networkId,
26 armnn::IRuntime* runtime,
surmeh01deb3bdb2018-07-05 12:06:04 +010027 const V1_0::Model& model,
telsoa015307bc12018-03-09 13:51:08 +000028 const std::string& requestInputsAndOutputsDumpDir);
29
30 virtual ~ArmnnPreparedModel();
31
32 virtual Return<ErrorStatus> execute(const Request& request,
33 const ::android::sp<IExecutionCallback>& callback) override;
34
35 /// execute the graph prepared from the request
36 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
37 std::shared_ptr<armnn::InputTensors>& pInputTensors,
38 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
39 const ::android::sp<IExecutionCallback>& callback);
40
41 /// Executes this model with dummy inputs (e.g. all zeroes).
42 void ExecuteWithDummyInputs();
43
44private:
45
46 template <typename TensorBindingCollection>
47 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
48
49 armnn::NetworkId m_NetworkId;
50 armnn::IRuntime* m_Runtime;
surmeh01deb3bdb2018-07-05 12:06:04 +010051 V1_0::Model m_Model;
telsoa015307bc12018-03-09 13:51:08 +000052 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
53 // It is specific to this class, so it is declared as static here
54 static RequestThread m_RequestThread;
55 uint32_t m_RequestCount;
56 const std::string& m_RequestInputsAndOutputsDumpDir;
57};
58
59class AndroidNnCpuExecutorPreparedModel : public IPreparedModel
60{
61public:
62
surmeh01deb3bdb2018-07-05 12:06:04 +010063 AndroidNnCpuExecutorPreparedModel(const V1_0::Model& model, const std::string& requestInputsAndOutputsDumpDir);
telsoa015307bc12018-03-09 13:51:08 +000064 virtual ~AndroidNnCpuExecutorPreparedModel() { }
65
66 bool Initialize();
67
68 virtual Return<ErrorStatus> execute(const Request& request,
69 const ::android::sp<IExecutionCallback>& callback) override;
70
71private:
72
73 void DumpTensorsIfRequired(
74 char const* tensorNamePrefix,
75 const hidl_vec<uint32_t>& operandIndices,
76 const hidl_vec<RequestArgument>& requestArgs,
77 const std::vector<android::nn::RunTimePoolInfo>& requestPoolInfos);
78
surmeh01deb3bdb2018-07-05 12:06:04 +010079 V1_0::Model m_Model;
telsoa015307bc12018-03-09 13:51:08 +000080 std::vector<android::nn::RunTimePoolInfo> m_ModelPoolInfos;
81 const std::string& m_RequestInputsAndOutputsDumpDir;
82 uint32_t m_RequestCount;
83};
84
85}