blob: 4e883b6b14a126e612dd52228f37e754ae4e0b98 [file] [log] [blame]
Mike Kellyb5fdf382019-06-11 16:35:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ArmnnDriver.hpp"
9#include "ArmnnDriverImpl.hpp"
10#include "RequestThread.hpp"
11#include "ModelToINetworkConverter.hpp"
12
13#include <NeuralNetworks.h>
14#include <armnn/ArmNN.hpp>
15
16#include <string>
17#include <vector>
18
19namespace armnn_driver
20{
21
22template <typename HalVersion>
23class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
24{
25public:
26 using HalModel = typename V1_2::Model;
27
28 ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
29 armnn::IRuntime* runtime,
30 const HalModel& model,
31 const std::string& requestInputsAndOutputsDumpDir,
32 const bool gpuProfilingEnabled);
33
34 virtual ~ArmnnPreparedModel_1_2();
35
36 virtual Return<ErrorStatus> execute(const Request& request,
37 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
38
39 virtual Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure,
40 const sp<V1_2::IExecutionCallback>& callback) override;
41
42 virtual Return<void> executeSynchronously(const Request &request,
43 MeasureTiming measure,
44 V1_2::IPreparedModel::executeSynchronously_cb cb) override;
45
46 virtual Return<void> configureExecutionBurst(
47 const sp<V1_2::IBurstCallback>& callback,
48 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
49 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
50 configureExecutionBurst_cb cb) override;
51
52 /// execute the graph prepared from the request
53 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
54 std::shared_ptr<armnn::InputTensors>& pInputTensors,
55 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
56 const ::android::sp<V1_0::IExecutionCallback>& callback);
57
58 /// Executes this model with dummy inputs (e.g. all zeroes).
59 /// \return false on failure, otherwise true
60 bool ExecuteWithDummyInputs();
61
62private:
63 template <typename ExecutionCallback>
64 Return <ErrorStatus> Execute(const Request &request, const sp <ExecutionCallback> &callback);
65
66 template <typename TensorBindingCollection>
67 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
68
69 armnn::NetworkId m_NetworkId;
70 armnn::IRuntime* m_Runtime;
71 V1_2::Model m_Model;
72 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
73 // It is specific to this class, so it is declared as static here
74 static RequestThread<ArmnnPreparedModel_1_2, HalVersion> m_RequestThread;
75 uint32_t m_RequestCount;
76 const std::string& m_RequestInputsAndOutputsDumpDir;
77 const bool m_GpuProfilingEnabled;
78};
79
80}