blob: 9a7518b21a277b77db56d700036204e70c3910ae [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "MockMemoryManager.hpp"
8#include <armnn/backends/TensorHandle.hpp>
9
10namespace armnn
11{
12
13// An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
14class MockTensorHandle : public ITensorHandle
15{
16public:
17 MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager);
18
19 MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
20
21 ~MockTensorHandle() override;
22
23 void Manage() override;
24
25 void Allocate() override;
26
27 ITensorHandle* GetParent() const override
28 {
29 return nullptr;
30 }
31
32 const void* Map(bool /* blocking = true */) const override;
33 using ITensorHandle::Map;
34
35 void Unmap() const override
36 {}
37
38 TensorShape GetStrides() const override
39 {
40 return GetUnpaddedTensorStrides(m_TensorInfo);
41 }
42
43 TensorShape GetShape() const override
44 {
45 return m_TensorInfo.GetShape();
46 }
47
48 const TensorInfo& GetTensorInfo() const
49 {
50 return m_TensorInfo;
51 }
52
53 MemorySourceFlags GetImportFlags() const override
54 {
55 return m_ImportFlags;
56 }
57
58 bool Import(void* memory, MemorySource source) override;
59 bool CanBeImported(void* memory, MemorySource source) override;
60
61private:
62 // Only used for testing
63 void CopyOutTo(void*) const override;
64 void CopyInFrom(const void*) override;
65
66 void* GetPointer() const;
67
68 MockTensorHandle(const MockTensorHandle& other) = delete; // noncopyable
69 MockTensorHandle& operator=(const MockTensorHandle& other) = delete; //noncopyable
70
71 TensorInfo m_TensorInfo;
72
73 std::shared_ptr<MockMemoryManager> m_MemoryManager;
74 MockMemoryManager::Pool* m_Pool;
75 mutable void* m_UnmanagedMemory;
76 MemorySourceFlags m_ImportFlags;
77 bool m_Imported;
78 bool m_IsImportEnabled;
79};
80
81} // namespace armnn