blob: be4d5a8d9220d746bada5d89403838e134b5e642 [file] [log] [blame]
Colm Donelan17948b52022-02-01 23:37:04 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "armnnTestUtils/MockTensorHandle.hpp"
7
8namespace armnn
9{
10
11MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager)
12 : m_TensorInfo(tensorInfo)
13 , m_MemoryManager(memoryManager)
14 , m_Pool(nullptr)
15 , m_UnmanagedMemory(nullptr)
16 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17 , m_Imported(false)
18 , m_IsImportEnabled(false)
19{}
20
21MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
22 : m_TensorInfo(tensorInfo)
23 , m_Pool(nullptr)
24 , m_UnmanagedMemory(nullptr)
25 , m_ImportFlags(importFlags)
26 , m_Imported(false)
27 , m_IsImportEnabled(true)
28{}
29
30MockTensorHandle::~MockTensorHandle()
31{
32 if (!m_Pool)
33 {
34 // unmanaged
35 if (!m_Imported)
36 {
37 ::operator delete(m_UnmanagedMemory);
38 }
39 }
40}
41
42void MockTensorHandle::Manage()
43{
44 if (!m_IsImportEnabled)
45 {
46 ARMNN_ASSERT_MSG(!m_Pool, "MockTensorHandle::Manage() called twice");
47 ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "MockTensorHandle::Manage() called after Allocate()");
48
49 m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50 }
51}
52
53void MockTensorHandle::Allocate()
54{
55 // If import is enabled, do not allocate the tensor
56 if (!m_IsImportEnabled)
57 {
58
59 if (!m_UnmanagedMemory)
60 {
61 if (!m_Pool)
62 {
63 // unmanaged
64 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
65 }
66 else
67 {
68 m_MemoryManager->Allocate(m_Pool);
69 }
70 }
71 else
72 {
73 throw InvalidArgumentException("MockTensorHandle::Allocate Trying to allocate a MockTensorHandle"
74 "that already has allocated memory.");
75 }
76 }
77}
78
79const void* MockTensorHandle::Map(bool /*unused*/) const
80{
81 return GetPointer();
82}
83
84void* MockTensorHandle::GetPointer() const
85{
86 if (m_UnmanagedMemory)
87 {
88 return m_UnmanagedMemory;
89 }
90 else if (m_Pool)
91 {
92 return m_MemoryManager->GetPointer(m_Pool);
93 }
94 else
95 {
96 throw NullPointerException("MockTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
97 }
98}
99
100void MockTensorHandle::CopyOutTo(void* dest) const
101{
102 const void* src = GetPointer();
103 ARMNN_ASSERT(src);
104 memcpy(dest, src, m_TensorInfo.GetNumBytes());
105}
106
107void MockTensorHandle::CopyInFrom(const void* src)
108{
109 void* dest = GetPointer();
110 ARMNN_ASSERT(dest);
111 memcpy(dest, src, m_TensorInfo.GetNumBytes());
112}
113
114bool MockTensorHandle::Import(void* memory, MemorySource source)
115{
116 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117 {
118 if (m_IsImportEnabled && source == MemorySource::Malloc)
119 {
120 // Check memory alignment
121 if (!CanBeImported(memory, source))
122 {
123 if (m_Imported)
124 {
125 m_Imported = false;
126 m_UnmanagedMemory = nullptr;
127 }
128
129 return false;
130 }
131
132 // m_UnmanagedMemory not yet allocated.
133 if (!m_Imported && !m_UnmanagedMemory)
134 {
135 m_UnmanagedMemory = memory;
136 m_Imported = true;
137 return true;
138 }
139
140 // m_UnmanagedMemory initially allocated with Allocate().
141 if (!m_Imported && m_UnmanagedMemory)
142 {
143 return false;
144 }
145
146 // m_UnmanagedMemory previously imported.
147 if (m_Imported)
148 {
149 m_UnmanagedMemory = memory;
150 return true;
151 }
152 }
153 }
154
155 return false;
156}
157
158bool MockTensorHandle::CanBeImported(void* memory, MemorySource source)
159{
160 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161 {
162 if (m_IsImportEnabled && source == MemorySource::Malloc)
163 {
164 uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165 if (reinterpret_cast<uintptr_t>(memory) % alignment)
166 {
167 return false;
168 }
169
170 return true;
171 }
172 }
173 return false;
174}
175
176} // namespace armnn