blob: 17dd1d4d063cc3db905faf77d6fc512f9e0c6642 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "TosaRefMemoryManager.hpp"

#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <utility>
11
12namespace armnn
13{
14
/// Returns the string identifier of the TosaRef tensor handle factory.
///
/// The value is a string literal, so the returned pointer stays valid for the whole
/// program and the call can be evaluated at compile time.
/// NOTE(review): presumably this is the value behind
/// TosaRefTensorHandleFactory::GetIdStatic() - confirm in the implementation file.
///
/// @return Pointer to the null-terminated literal "Arm/TosaRef/TensorHandleFactory".
constexpr const char* TosaRefTensorHandleFactoryId()
{
    return "Arm/TosaRef/TensorHandleFactory";
}
16
17class TosaRefTensorHandleFactory : public ITensorHandleFactory
18{
19
20public:
21 TosaRefTensorHandleFactory(std::shared_ptr<TosaRefMemoryManager> mgr)
22 : m_MemoryManager(mgr)
23 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
24 , m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
25 {}
26
27 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
28 TensorShape const& subTensorShape,
29 unsigned int const* subTensorOrigin) const override;
30
31 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
32
33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
34 DataLayout dataLayout) const override;
35
36 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
37 const bool IsMemoryManaged) const override;
38
39 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
40 DataLayout dataLayout,
41 const bool IsMemoryManaged) const override;
42
43 static const FactoryId& GetIdStatic();
44
45 const FactoryId& GetId() const override;
46
47 bool SupportsSubTensors() const override;
48
49 MemorySourceFlags GetExportFlags() const override;
50
51 MemorySourceFlags GetImportFlags() const override;
52
53private:
54 mutable std::shared_ptr<TosaRefMemoryManager> m_MemoryManager;
55 MemorySourceFlags m_ImportFlags;
56 MemorySourceFlags m_ExportFlags;
57
58};
59
60} // namespace armnn
61