blob: e68614a0650c29cab93fe36a5ad36be4901b0947 [file] [log] [blame]
Mike Kellyb5fdf382019-06-11 16:35:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ArmnnDriver.hpp"
9#include "ArmnnDriverImpl.hpp"
10#include "RequestThread.hpp"
11#include "ModelToINetworkConverter.hpp"
12
13#include <NeuralNetworks.h>
14#include <armnn/ArmNN.hpp>
15
16#include <string>
17#include <vector>
18
19namespace armnn_driver
20{
21
Derek Lamberti4de83c52020-03-17 13:40:18 +000022using CallbackAsync_1_2 = std::function<
23 void(V1_0::ErrorStatus errorStatus,
24 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
25 const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
26 std::string callingFunction)>;
Mike Kelly65c42dc2019-07-22 14:06:00 +010027
Derek Lamberti4de83c52020-03-17 13:40:18 +000028struct ExecutionContext_1_2
Mike Kelly65c42dc2019-07-22 14:06:00 +010029{
Derek Lamberti4de83c52020-03-17 13:40:18 +000030 ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
31 ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
Mike Kelly65c42dc2019-07-22 14:06:00 +010032 TimePoint driverStart;
Mike Kelly65c42dc2019-07-22 14:06:00 +010033};
34
Derek Lamberti4de83c52020-03-17 13:40:18 +000035using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
36
Mike Kellyb5fdf382019-06-11 16:35:25 +010037template <typename HalVersion>
38class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
39{
40public:
41 using HalModel = typename V1_2::Model;
42
43 ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
44 armnn::IRuntime* runtime,
45 const HalModel& model,
46 const std::string& requestInputsAndOutputsDumpDir,
47 const bool gpuProfilingEnabled);
48
49 virtual ~ArmnnPreparedModel_1_2();
50
Kevin Mayec1e5b82020-02-26 17:00:39 +000051 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
52 const sp<V1_0::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010053
Kevin Mayec1e5b82020-02-26 17:00:39 +000054 virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, MeasureTiming measure,
55 const sp<V1_2::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010056
Kevin Mayec1e5b82020-02-26 17:00:39 +000057 virtual Return<void> executeSynchronously(const V1_0::Request &request,
Mike Kellyb5fdf382019-06-11 16:35:25 +010058 MeasureTiming measure,
59 V1_2::IPreparedModel::executeSynchronously_cb cb) override;
60
61 virtual Return<void> configureExecutionBurst(
62 const sp<V1_2::IBurstCallback>& callback,
63 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
64 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
65 configureExecutionBurst_cb cb) override;
66
67 /// execute the graph prepared from the request
Derek Lamberti4de83c52020-03-17 13:40:18 +000068 template<typename CallbackContext>
69 bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
70 armnn::InputTensors& inputTensors,
71 armnn::OutputTensors& outputTensors,
72 CallbackContext callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +010073
74 /// Executes this model with dummy inputs (e.g. all zeroes).
75 /// \return false on failure, otherwise true
76 bool ExecuteWithDummyInputs();
77
78private:
Derek Lamberti4de83c52020-03-17 13:40:18 +000079 Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
80 MeasureTiming measureTiming,
81 CallbackAsync_1_2 callback);
82
83 Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
84 armnn::InputTensors& inputs,
85 const V1_0::Request& request,
86 const std::vector<android::nn::RunTimePoolInfo>& memPools);
87
88 Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
89 armnn::OutputTensors& outputs,
90 std::vector<OutputShape> &outputShapes,
91 const V1_0::Request& request,
92 const std::vector<android::nn::RunTimePoolInfo>& memPools);
93
94 Return <V1_0::ErrorStatus> PrepareMemoryForIO(
95 armnn::InputTensors& inputs,
96 armnn::OutputTensors& outputs,
97 std::vector<android::nn::RunTimePoolInfo>& memPools,
98 const V1_0::Request& request,
99 CallbackAsync_1_2 callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +0100100
101 template <typename TensorBindingCollection>
102 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
103
Mike Kelly65c42dc2019-07-22 14:06:00 +0100104 armnn::NetworkId m_NetworkId;
105 armnn::IRuntime* m_Runtime;
106 V1_2::Model m_Model;
Mike Kellyb5fdf382019-06-11 16:35:25 +0100107 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
108 // It is specific to this class, so it is declared as static here
Derek Lamberti4de83c52020-03-17 13:40:18 +0000109 static RequestThread<ArmnnPreparedModel_1_2,
110 HalVersion,
111 CallbackContext_1_2> m_RequestThread;
Mike Kelly65c42dc2019-07-22 14:06:00 +0100112 uint32_t m_RequestCount;
113 const std::string& m_RequestInputsAndOutputsDumpDir;
114 const bool m_GpuProfilingEnabled;
Mike Kellyb5fdf382019-06-11 16:35:25 +0100115};
116
117}