blob: aaffc8ab6c11c211d2f2ec9562a20edeb400daea [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "TosaRefTensorHandle.hpp"
6
7namespace armnn
8{
9
10TosaRefTensorHandle::TosaRefTensorHandle(const TensorInfo& tensorInfo,
11 std::shared_ptr<TosaRefMemoryManager>& memoryManager)
12 : m_TensorInfo(tensorInfo)
13 , m_MemoryManager(memoryManager)
14 , m_Pool(nullptr)
15 , m_UnmanagedMemory(nullptr)
16 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17 , m_Imported(false)
18 , m_IsImportEnabled(false)
19{}
20
21TosaRefTensorHandle::TosaRefTensorHandle(const TensorInfo& tensorInfo,
22 MemorySourceFlags importFlags)
23 : m_TensorInfo(tensorInfo)
24 , m_Pool(nullptr)
25 , m_UnmanagedMemory(nullptr)
26 , m_ImportFlags(importFlags)
27 , m_Imported(false)
28 , m_IsImportEnabled(true)
29{}
30
31TosaRefTensorHandle::~TosaRefTensorHandle()
32{
33 if (!m_Pool)
34 {
35 // unmanaged
36 if (!m_Imported)
37 {
38 ::operator delete(m_UnmanagedMemory);
39 }
40 }
41}
42
43void TosaRefTensorHandle::Manage()
44{
45 if (!m_IsImportEnabled)
46 {
47 ARMNN_ASSERT_MSG(!m_Pool, "TosaRefTensorHandle::Manage() called twice");
48 ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "TosaRefTensorHandle::Manage() called after Allocate()");
49
50 m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
51 }
52}
53
54void TosaRefTensorHandle::Allocate()
55{
56 // If import is enabled, do not allocate the tensor
57 if (!m_IsImportEnabled)
58 {
59
60 if (!m_UnmanagedMemory)
61 {
62 if (!m_Pool)
63 {
64 // unmanaged
65 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
66 }
67 else
68 {
69 m_MemoryManager->Allocate(m_Pool);
70 }
71 }
72 else
73 {
74 throw InvalidArgumentException("TosaRefTensorHandle::Allocate Trying to allocate a TosaRefTensorHandle"
75 "that already has allocated memory.");
76 }
77 }
78}
79
80const void* TosaRefTensorHandle::Map(bool /*unused*/) const
81{
82 return GetPointer();
83}
84
85void* TosaRefTensorHandle::GetPointer() const
86{
87 if (m_UnmanagedMemory)
88 {
89 return m_UnmanagedMemory;
90 }
91 else if (m_Pool)
92 {
93 return m_MemoryManager->GetPointer(m_Pool);
94 }
95 else
96 {
97 throw NullPointerException("TosaRefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
98 }
99}
100
101void TosaRefTensorHandle::CopyOutTo(void* dest) const
102{
David Monahan6a1d5062023-08-29 09:10:50 +0100103 const void* src = GetPointer();
104 if (src == nullptr)
105 {
106 throw NullPointerException("TosaRefTensorHandle::CopyOutTo called with a null src pointer");
107 }
108 if (dest == nullptr)
109 {
110 throw NullPointerException("TosaRefTensorHandle::CopyOutTo called with a null dest pointer");
111 }
112 memcpy(dest, src, GetTensorInfo().GetNumBytes());
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100113}
114
115void TosaRefTensorHandle::CopyInFrom(const void* src)
116{
David Monahan6a1d5062023-08-29 09:10:50 +0100117 void* dest = GetPointer();
118 if (dest == nullptr)
119 {
120 throw NullPointerException("TosaRefTensorHandle::CopyInFrom called with a null dest pointer");
121 }
122 if (src == nullptr)
123 {
124 throw NullPointerException("TosaRefTensorHandle::CopyInFrom called with a null src pointer");
125 }
126 memcpy(dest, src, GetTensorInfo().GetNumBytes());
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100127}
128
129bool TosaRefTensorHandle::Import(void* memory, MemorySource source)
130{
131 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
132 {
133 if (m_IsImportEnabled && source == MemorySource::Malloc)
134 {
135 // Check memory alignment
136 if(!CanBeImported(memory, source))
137 {
138 if (m_Imported)
139 {
140 m_Imported = false;
141 m_UnmanagedMemory = nullptr;
142 }
143 return false;
144 }
145
146 // m_UnmanagedMemory not yet allocated.
147 if (!m_Imported && !m_UnmanagedMemory)
148 {
149 m_UnmanagedMemory = memory;
150 m_Imported = true;
151 return true;
152 }
153
154 // m_UnmanagedMemory initially allocated with Allocate().
155 if (!m_Imported && m_UnmanagedMemory)
156 {
157 return false;
158 }
159
160 // m_UnmanagedMemory previously imported.
161 if (m_Imported)
162 {
163 m_UnmanagedMemory = memory;
164 return true;
165 }
166 }
167 }
168
169 return false;
170}
171
172bool TosaRefTensorHandle::CanBeImported(void* memory, MemorySource source)
173{
174 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
175 {
176 if (m_IsImportEnabled && source == MemorySource::Malloc)
177 {
178 uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
179 if (reinterpret_cast<uintptr_t>(memory) % alignment)
180 {
181 return false;
182 }
183 return true;
184 }
185 }
186 return false;
187}
188
189}