blob: ed95577cca431f5d42756cb601c85687e8e840a4 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#pragma once
6
7#include "backends/WorkloadDataFwd.hpp"
8
9#include <string>
10#include <vector>
11
12#include <memory>
13#include <set>
14
15#include <boost/assert.hpp>
16
17#include "armnn/INetwork.hpp"
18#include "armnn/Types.hpp"
19#include "armnn/Descriptors.hpp"
20#include "armnn/Tensor.hpp"
21#include "ITensorHandle.hpp"
22
23namespace armnn
24{
25
26class ITensorHandle;
27class IWorkloadFactory;
28class OutputSlot;
29class WorkloadDataCollector;
30
31class OutputHandler
32{
33public:
telsoa01c577f2c2018-08-31 09:22:23 +010034 /// @brief - Sets the TensorInfo used by this output handler.
35 /// @param tensorInfo - TensorInfo for the output.
telsoa014fcda012018-03-09 14:13:49 +000036 void SetTensorInfo(const TensorInfo& tensorInfo);
37
telsoa01c577f2c2018-08-31 09:22:23 +010038 /// @brief - Creates tensor handlers used by the intermediate tensors. Does not allocate memory.
39 /// @param factory - Factory to be used for handler creation.
telsoa014fcda012018-03-09 14:13:49 +000040 void CreateTensorHandles(const IWorkloadFactory& factory);
41
telsoa01c577f2c2018-08-31 09:22:23 +010042 /// @brief - Gets the matching TensorInfo for the output.
43 /// @return - References to the output TensorInfo.
telsoa014fcda012018-03-09 14:13:49 +000044 const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
45
telsoa01c577f2c2018-08-31 09:22:23 +010046 /// @brief - Gets the allocated tensor memory.
47 /// @return - Pointer to the tensor memory.
telsoa014fcda012018-03-09 14:13:49 +000048 ITensorHandle* GetData() const { return m_TensorHandle.get(); }
49
telsoa01c577f2c2018-08-31 09:22:23 +010050 /// Fill the outputs for a given queue descriptor.
telsoa014fcda012018-03-09 14:13:49 +000051 void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const;
52
53 void SetData(std::unique_ptr<ITensorHandle> data) { m_TensorHandle = std::move(data); }
54
telsoa014fcda012018-03-09 14:13:49 +000055 /// @brief Returns true if SetTensorInfo() has been called at least once on this.
56 bool IsTensorInfoSet() const { return m_bTensorInfoSet; }
57private:
58 std::unique_ptr<ITensorHandle> m_TensorHandle;
59 TensorInfo m_TensorInfo;
60 bool m_bTensorInfoSet = false;
61};
62
63} //namespace armnn