blob: 6cde3263a0a586cb74b9626500b7f36e30bf0ba3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <backendsCommon/CpuTensorHandle.hpp>

#include "RefMemoryManager.hpp"

#include <memory>

Matthew Bentham4cefc412019-06-18 16:14:34 +010011namespace armnn
12{
13
14// An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
15class RefTensorHandle : public ITensorHandle
16{
17public:
Matthew Bentham7c1603a2019-06-21 17:22:23 +010018 RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager);
Matthew Bentham4cefc412019-06-18 16:14:34 +010019
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010020 RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager,
21 MemorySourceFlags importFlags);
22
Matthew Bentham4cefc412019-06-18 16:14:34 +010023 ~RefTensorHandle();
24
Matthew Bentham7c1603a2019-06-21 17:22:23 +010025 virtual void Manage() override;
26
27 virtual void Allocate() override;
Matthew Bentham4cefc412019-06-18 16:14:34 +010028
29 virtual ITensorHandle* GetParent() const override
30 {
31 return nullptr;
32 }
33
Matthew Bentham7c1603a2019-06-21 17:22:23 +010034 virtual const void* Map(bool /* blocking = true */) const override;
35 using ITensorHandle::Map;
Matthew Bentham4cefc412019-06-18 16:14:34 +010036
37 virtual void Unmap() const override
38 {}
39
Matthew Bentham4cefc412019-06-18 16:14:34 +010040 TensorShape GetStrides() const override
41 {
42 return GetUnpaddedTensorStrides(m_TensorInfo);
43 }
44
45 TensorShape GetShape() const override
46 {
47 return m_TensorInfo.GetShape();
48 }
49
50 const TensorInfo& GetTensorInfo() const
51 {
52 return m_TensorInfo;
53 }
54
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010055 virtual MemorySourceFlags GetImportFlags() const override
56 {
57 return m_ImportFlags;
58 }
59
60 virtual bool Import(void* memory, MemorySource source) override;
61
Matthew Bentham4cefc412019-06-18 16:14:34 +010062private:
63 // Only used for testing
64 void CopyOutTo(void*) const override;
65 void CopyInFrom(const void*) override;
66
Matthew Bentham7c1603a2019-06-21 17:22:23 +010067 void* GetPointer() const;
Matthew Bentham4cefc412019-06-18 16:14:34 +010068
Matthew Bentham7c1603a2019-06-21 17:22:23 +010069 RefTensorHandle(const RefTensorHandle& other) = delete; // noncopyable
70 RefTensorHandle& operator=(const RefTensorHandle& other) = delete; //noncopyable
Matthew Bentham4cefc412019-06-18 16:14:34 +010071
72 TensorInfo m_TensorInfo;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010073
74 std::shared_ptr<RefMemoryManager> m_MemoryManager;
75 RefMemoryManager::Pool* m_Pool;
76 mutable void *m_UnmanagedMemory;
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010077 MemorySourceFlags m_ImportFlags;
78 bool m_Imported;
Matthew Bentham4cefc412019-06-18 16:14:34 +010079};
80
Matthew Bentham7c1603a2019-06-21 17:22:23 +010081}