blob: bc2f7d7e3b402143a4713caf11ffadc7e1d7dfd8 [file] [log] [blame]
Colm Donelan17948b52022-02-01 23:37:04 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "MockMemoryManager.hpp"
Colm Donelana98e79a2022-12-06 21:32:29 +00008
Colm Donelan17948b52022-02-01 23:37:04 +00009#include <armnn/backends/TensorHandle.hpp>
Colm Donelana98e79a2022-12-06 21:32:29 +000010#include <armnn/MemorySources.hpp>
11#include <armnn/Tensor.hpp>
12#include <armnn/Types.hpp>
13#include <armnn/backends/ITensorHandle.hpp>
14#include <memory>
Colm Donelan17948b52022-02-01 23:37:04 +000015
16namespace armnn
17{
18
19// An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
20class MockTensorHandle : public ITensorHandle
21{
22public:
23 MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager);
24
25 MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
26
27 ~MockTensorHandle() override;
28
29 void Manage() override;
30
31 void Allocate() override;
32
33 ITensorHandle* GetParent() const override
34 {
35 return nullptr;
36 }
37
38 const void* Map(bool /* blocking = true */) const override;
39 using ITensorHandle::Map;
40
41 void Unmap() const override
42 {}
43
44 TensorShape GetStrides() const override
45 {
46 return GetUnpaddedTensorStrides(m_TensorInfo);
47 }
48
49 TensorShape GetShape() const override
50 {
51 return m_TensorInfo.GetShape();
52 }
53
54 const TensorInfo& GetTensorInfo() const
55 {
56 return m_TensorInfo;
57 }
58
59 MemorySourceFlags GetImportFlags() const override
60 {
61 return m_ImportFlags;
62 }
63
64 bool Import(void* memory, MemorySource source) override;
65 bool CanBeImported(void* memory, MemorySource source) override;
66
67private:
68 // Only used for testing
69 void CopyOutTo(void*) const override;
70 void CopyInFrom(const void*) override;
71
72 void* GetPointer() const;
73
74 MockTensorHandle(const MockTensorHandle& other) = delete; // noncopyable
75 MockTensorHandle& operator=(const MockTensorHandle& other) = delete; //noncopyable
76
77 TensorInfo m_TensorInfo;
78
79 std::shared_ptr<MockMemoryManager> m_MemoryManager;
80 MockMemoryManager::Pool* m_Pool;
81 mutable void* m_UnmanagedMemory;
82 MemorySourceFlags m_ImportFlags;
83 bool m_Imported;
84 bool m_IsImportEnabled;
85};
86
87} // namespace armnn