blob: 48f8cf44fa5db29027cf062d2f2f28247cd3c42a [file] [log] [blame]
Narumol Prangnawarat867eba52020-02-03 12:29:56 +00001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SampleTensorHandle.hpp"
7
8namespace armnn
9{
10
11SampleTensorHandle::SampleTensorHandle(const TensorInfo &tensorInfo,
12 std::shared_ptr<SampleMemoryManager> &memoryManager)
13 : m_TensorInfo(tensorInfo),
14 m_MemoryManager(memoryManager),
15 m_Pool(nullptr),
16 m_UnmanagedMemory(nullptr),
17 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
18 m_Imported(false)
19{
20
21}
22
23SampleTensorHandle::SampleTensorHandle(const TensorInfo& tensorInfo,
24 std::shared_ptr<SampleMemoryManager> &memoryManager,
25 MemorySourceFlags importFlags)
26 : m_TensorInfo(tensorInfo),
27 m_MemoryManager(memoryManager),
28 m_Pool(nullptr),
29 m_UnmanagedMemory(nullptr),
30 m_ImportFlags(importFlags),
31 m_Imported(false)
32{
33
34}
35
36SampleTensorHandle::~SampleTensorHandle()
37{
38 if (!m_Pool)
39 {
40 // unmanaged
41 if (!m_Imported)
42 {
43 ::operator delete(m_UnmanagedMemory);
44 }
45 }
46}
47
48void SampleTensorHandle::Manage()
49{
50 m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
51}
52
53void SampleTensorHandle::Allocate()
54{
55 if (!m_UnmanagedMemory)
56 {
57 if (!m_Pool)
58 {
59 // unmanaged
60 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
61 }
62 else
63 {
64 m_MemoryManager->Allocate(m_Pool);
65 }
66 }
67 else
68 {
69 throw InvalidArgumentException("SampleTensorHandle::Allocate Trying to allocate a SampleTensorHandle"
70 "that already has allocated memory.");
71 }
72}
73
74const void* SampleTensorHandle::Map(bool /*unused*/) const
75{
76 return GetPointer();
77}
78
79void* SampleTensorHandle::GetPointer() const
80{
81 if (m_UnmanagedMemory)
82 {
83 return m_UnmanagedMemory;
84 }
85 else
86 {
87 return m_MemoryManager->GetPointer(m_Pool);
88 }
89}
90
91bool SampleTensorHandle::Import(void* memory, MemorySource source)
92{
93
94 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
95 {
96 if (source == MemorySource::Malloc)
97 {
98 // Check memory alignment
99 constexpr uintptr_t alignment = sizeof(size_t);
100 if (reinterpret_cast<uintptr_t>(memory) % alignment)
101 {
102 if (m_Imported)
103 {
104 m_Imported = false;
105 m_UnmanagedMemory = nullptr;
106 }
107
108 return false;
109 }
110
111 // m_UnmanagedMemory not yet allocated.
112 if (!m_Imported && !m_UnmanagedMemory)
113 {
114 m_UnmanagedMemory = memory;
115 m_Imported = true;
116 return true;
117 }
118
119 // m_UnmanagedMemory initially allocated with Allocate().
120 if (!m_Imported && m_UnmanagedMemory)
121 {
122 return false;
123 }
124
125 // m_UnmanagedMemory previously imported.
126 if (m_Imported)
127 {
128 m_UnmanagedMemory = memory;
129 return true;
130 }
131 }
132 }
133
134 return false;
135}
136
137}