blob: 431d56bd37aa867565b64b9779aae0de9fb9f7c6 [file] [log] [blame]
Francis Murtaghbf354142022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <armnn/backends/TensorHandle.hpp>
8
9#include "TosaRefMemoryManager.hpp"
10
11namespace armnn
12{
13
14// An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
15class TosaRefTensorHandle : public ITensorHandle
16{
17public:
18 TosaRefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<TosaRefMemoryManager> &memoryManager);
19
20 TosaRefTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
21
22 ~TosaRefTensorHandle();
23
24 virtual void Manage() override;
25
26 virtual void Allocate() override;
27
28 virtual ITensorHandle* GetParent() const override
29 {
30 return nullptr;
31 }
32
33 virtual const void* Map(bool /* blocking = true */) const override;
34 using ITensorHandle::Map;
35
36 virtual void Unmap() const override
37 {}
38
39 TensorShape GetStrides() const override
40 {
41 return GetUnpaddedTensorStrides(m_TensorInfo);
42 }
43
44 TensorShape GetShape() const override
45 {
46 return m_TensorInfo.GetShape();
47 }
48
49 const TensorInfo& GetTensorInfo() const
50 {
51 return m_TensorInfo;
52 }
53
54 virtual MemorySourceFlags GetImportFlags() const override
55 {
56 return m_ImportFlags;
57 }
58
59 virtual bool Import(void* memory, MemorySource source) override;
60 virtual bool CanBeImported(void* memory, MemorySource source) override;
61
62private:
63 // Only used for testing
64 void CopyOutTo(void*) const override;
65 void CopyInFrom(const void*) override;
66
67 void* GetPointer() const;
68
69 TosaRefTensorHandle(const TosaRefTensorHandle& other) = delete; // noncopyable
70 TosaRefTensorHandle& operator=(const TosaRefTensorHandle& other) = delete; //noncopyable
71
72 TensorInfo m_TensorInfo;
73
74 std::shared_ptr<TosaRefMemoryManager> m_MemoryManager;
75 TosaRefMemoryManager::Pool* m_Pool;
76 mutable void* m_UnmanagedMemory;
77 MemorySourceFlags m_ImportFlags;
78 bool m_Imported;
79 bool m_IsImportEnabled;
80};
81
82}