blob: 2a44ed6ada22450a05d1970daecbf1868b93ae69 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include <backendsCommon/CpuTensorHandle.hpp>
8
9#include "SampleMemoryManager.hpp"
10
11namespace armnn
12{
13
14// An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
15class SampleTensorHandle : public ITensorHandle
16{
17public:
18 SampleTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<SampleMemoryManager> &memoryManager);
19
Narumol Prangnawarat0739fee2020-08-11 11:24:25 +010020 SampleTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
Narumol Prangnawarat867eba52020-02-03 12:29:56 +000021
22 ~SampleTensorHandle();
23
24 virtual void Manage() override;
25
26 virtual void Allocate() override;
27
28 virtual ITensorHandle* GetParent() const override
29 {
30 return nullptr;
31 }
32
33 virtual const void* Map(bool /* blocking = true */) const override;
34 using ITensorHandle::Map;
35
36 virtual void Unmap() const override
37 {}
38
39 TensorShape GetStrides() const override
40 {
41 return GetUnpaddedTensorStrides(m_TensorInfo);
42 }
43
44 TensorShape GetShape() const override
45 {
46 return m_TensorInfo.GetShape();
47 }
48
49 const TensorInfo& GetTensorInfo() const
50 {
51 return m_TensorInfo;
52 }
53
54 virtual MemorySourceFlags GetImportFlags() const override
55 {
56 return m_ImportFlags;
57 }
58
59 virtual bool Import(void* memory, MemorySource source) override;
60
61private:
Narumol Prangnawarat0739fee2020-08-11 11:24:25 +010062 // Only used for testing
63 void CopyOutTo(void*) const override;
64 void CopyInFrom(const void*) override;
65
Narumol Prangnawarat867eba52020-02-03 12:29:56 +000066 void* GetPointer() const;
67
68 SampleTensorHandle(const SampleTensorHandle& other) = delete; // noncopyable
69 SampleTensorHandle& operator=(const SampleTensorHandle& other) = delete; //noncopyable
70
71 TensorInfo m_TensorInfo;
72
73 std::shared_ptr<SampleMemoryManager> m_MemoryManager;
74 SampleMemoryManager::Pool* m_Pool;
75 mutable void *m_UnmanagedMemory;
76 MemorySourceFlags m_ImportFlags;
77 bool m_Imported;
78};
79
80}