blob: b97895e849640e457073386ec870de281f08c82f [file] [log] [blame]
Mike Kellyb5fdf382019-06-11 16:35:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ArmnnDriver.hpp"
9#include "ArmnnDriverImpl.hpp"
10#include "RequestThread.hpp"
11#include "ModelToINetworkConverter.hpp"
12
13#include <NeuralNetworks.h>
14#include <armnn/ArmNN.hpp>
15
16#include <string>
17#include <vector>
18
19namespace armnn_driver
20{
21
Mike Kelly65c42dc2019-07-22 14:06:00 +010022typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status,
23 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
24 const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
25 std::string callingFunction)> armnnExecuteCallback_1_2;
26
27struct ArmnnCallback_1_2
28{
29 armnnExecuteCallback_1_2 callback;
30 TimePoint driverStart;
31 MeasureTiming measureTiming;
32};
33
Mike Kellyb5fdf382019-06-11 16:35:25 +010034template <typename HalVersion>
35class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
36{
37public:
38 using HalModel = typename V1_2::Model;
39
40 ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
41 armnn::IRuntime* runtime,
42 const HalModel& model,
43 const std::string& requestInputsAndOutputsDumpDir,
44 const bool gpuProfilingEnabled);
45
46 virtual ~ArmnnPreparedModel_1_2();
47
48 virtual Return<ErrorStatus> execute(const Request& request,
Mike Kelly65c42dc2019-07-22 14:06:00 +010049 const sp<V1_0::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010050
51 virtual Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure,
52 const sp<V1_2::IExecutionCallback>& callback) override;
53
54 virtual Return<void> executeSynchronously(const Request &request,
55 MeasureTiming measure,
56 V1_2::IPreparedModel::executeSynchronously_cb cb) override;
57
58 virtual Return<void> configureExecutionBurst(
59 const sp<V1_2::IBurstCallback>& callback,
60 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
61 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
62 configureExecutionBurst_cb cb) override;
63
64 /// execute the graph prepared from the request
65 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
66 std::shared_ptr<armnn::InputTensors>& pInputTensors,
67 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
Mike Kelly65c42dc2019-07-22 14:06:00 +010068 ArmnnCallback_1_2 callbackDescriptor);
Mike Kellyb5fdf382019-06-11 16:35:25 +010069
70 /// Executes this model with dummy inputs (e.g. all zeroes).
71 /// \return false on failure, otherwise true
72 bool ExecuteWithDummyInputs();
73
74private:
Mike Kelly65c42dc2019-07-22 14:06:00 +010075 Return <ErrorStatus> Execute(const Request& request,
76 MeasureTiming measureTiming,
77 armnnExecuteCallback_1_2 callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +010078
79 template <typename TensorBindingCollection>
80 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
81
Mike Kelly65c42dc2019-07-22 14:06:00 +010082 armnn::NetworkId m_NetworkId;
83 armnn::IRuntime* m_Runtime;
84 V1_2::Model m_Model;
Mike Kellyb5fdf382019-06-11 16:35:25 +010085 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
86 // It is specific to this class, so it is declared as static here
Mike Kelly65c42dc2019-07-22 14:06:00 +010087 static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread;
88 uint32_t m_RequestCount;
89 const std::string& m_RequestInputsAndOutputsDumpDir;
90 const bool m_GpuProfilingEnabled;
Mike Kellyb5fdf382019-06-11 16:35:25 +010091};
92
93}