blob: 249b915ce12b27c6153bd78392c55b2713859b47 [file] [log] [blame]
David Monahan8a570462023-11-22 13:24:25 +00001//
2// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include "GpuFsaTensorHandle.hpp"
6
7namespace armnn
8{
9GpuFsaTensorHandle::GpuFsaTensorHandle(const TensorInfo& tensorInfo,
10 std::shared_ptr<GpuFsaMemoryManager>& memoryManager)
11 : m_TensorInfo(tensorInfo)
12 , m_MemoryManager(memoryManager)
13 , m_Pool(nullptr)
14 , m_UnmanagedMemory(nullptr)
15 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
16 , m_Imported(false)
17 , m_IsImportEnabled(false)
18{}
19
20GpuFsaTensorHandle::GpuFsaTensorHandle(const TensorInfo& tensorInfo,
21 MemorySourceFlags importFlags)
22 : m_TensorInfo(tensorInfo)
23 , m_Pool(nullptr)
24 , m_UnmanagedMemory(nullptr)
25 , m_ImportFlags(importFlags)
26 , m_Imported(false)
27 , m_IsImportEnabled(true)
28{}
29
30GpuFsaTensorHandle::~GpuFsaTensorHandle()
31{
32 if (!m_Pool)
33 {
34 // unmanaged
35 if (!m_Imported)
36 {
37 ::operator delete(m_UnmanagedMemory);
38 }
39 }
40}
41
42void GpuFsaTensorHandle::Manage()
43{
44 if (!m_IsImportEnabled)
45 {
46 if (m_Pool == nullptr)
47 {
48 throw MemoryValidationException("GpuFsaTensorHandle::Manage() called twice");
49 }
50 if (m_UnmanagedMemory == nullptr)
51 {
52 throw MemoryValidationException("GpuFsaTensorHandle::Manage() called after Allocate()");
53 }
54
55 m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
56 }
57}
58
59void GpuFsaTensorHandle::Allocate()
60{
61 // If import is enabled, do not allocate the tensor
62 if (!m_IsImportEnabled)
63 {
64
65 if (!m_UnmanagedMemory)
66 {
67 if (!m_Pool)
68 {
69 // unmanaged
70 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
71 }
72 else
73 {
74 m_MemoryManager->Allocate(m_Pool);
75 }
76 }
77 else
78 {
79 throw InvalidArgumentException("GpuFsaTensorHandle::Allocate Trying to allocate a GpuFsaTensorHandle"
80 "that already has allocated memory.");
81 }
82 }
83}
84
85const void* GpuFsaTensorHandle::Map(bool /*unused*/) const
86{
87 return GetPointer();
88}
89
90void* GpuFsaTensorHandle::GetPointer() const
91{
92 if (m_UnmanagedMemory)
93 {
94 return m_UnmanagedMemory;
95 }
96 else if (m_Pool)
97 {
98 return m_MemoryManager->GetPointer(m_Pool);
99 }
100 else
101 {
102 throw NullPointerException("GpuFsaTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
103 }
104}
105
106void GpuFsaTensorHandle::CopyOutTo(void* dest) const
107{
108 const void *src = GetPointer();
109 if (src == nullptr)
110 {
111 throw MemoryValidationException("GpuFsaTensorhandle: CopyOutTo: Invalid memory src pointer");
112 }
113 memcpy(dest, src, m_TensorInfo.GetNumBytes());
114}
115
116void GpuFsaTensorHandle::CopyInFrom(const void* src)
117{
118 void *dest = GetPointer();
119 if (dest == nullptr)
120 {
121 throw MemoryValidationException("GpuFsaTensorhandle: CopyInFrom: Invalid memory dest pointer");
122 }
123 memcpy(dest, src, m_TensorInfo.GetNumBytes());
124}
125
126bool GpuFsaTensorHandle::Import(void* memory, MemorySource source)
127{
128 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
129 {
130 if (m_IsImportEnabled && source == MemorySource::Malloc)
131 {
132 // Check memory alignment
133 if(!CanBeImported(memory, source))
134 {
135 if (m_Imported)
136 {
137 m_Imported = false;
138 m_UnmanagedMemory = nullptr;
139 }
140 return false;
141 }
142
143 // m_UnmanagedMemory not yet allocated.
144 if (!m_Imported && !m_UnmanagedMemory)
145 {
146 m_UnmanagedMemory = memory;
147 m_Imported = true;
148 return true;
149 }
150
151 // m_UnmanagedMemory initially allocated with Allocate().
152 if (!m_Imported && m_UnmanagedMemory)
153 {
154 return false;
155 }
156
157 // m_UnmanagedMemory previously imported.
158 if (m_Imported)
159 {
160 m_UnmanagedMemory = memory;
161 return true;
162 }
163 }
164 }
165
166 return false;
167}
168
169bool GpuFsaTensorHandle::CanBeImported(void* memory, MemorySource source)
170{
171 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
172 {
173 if (m_IsImportEnabled && source == MemorySource::Malloc)
174 {
175 uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
176 if (reinterpret_cast<uintptr_t>(memory) % alignment)
177 {
178 return false;
179 }
180 return true;
181 }
182 }
183 return false;
184}
185
186
187
188}