blob: 93a44259f625000aaef09d6446c673b9f34d95cf [file] [log] [blame]
David Monahan8a570462023-11-22 13:24:25 +00001//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <aclCommon/BaseMemoryManager.hpp>

#include <memory>
#include <utility>
10
11namespace armnn
12{
13
14constexpr const char * GpuFsaTensorHandleFactoryId() { return "Arm/GpuFsa/TensorHandleFactory"; }
15
16class GpuFsaTensorHandleFactory : public ITensorHandleFactory
17{
18
19public:
20 GpuFsaTensorHandleFactory(std::shared_ptr<GpuFsaMemoryManager> mgr)
21 : m_MemoryManager(mgr)
22 {}
23
24 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
25 TensorShape const& subTensorShape,
26 unsigned int const* subTensorOrigin) const override;
27
28 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
29
30 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
31 DataLayout dataLayout) const override;
32
33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
34 const bool IsMemoryManaged) const override;
35
36 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
37 DataLayout dataLayout,
38 const bool IsMemoryManaged) const override;
39
40 static const FactoryId& GetIdStatic();
41
42 const FactoryId& GetId() const override;
43
44 bool SupportsSubTensors() const override;
45
46 MemorySourceFlags GetExportFlags() const override;
47
48 MemorySourceFlags GetImportFlags() const override;
49
50private:
51 mutable std::shared_ptr<GpuFsaMemoryManager> m_MemoryManager;
52
53};
54
55} // namespace armnn