blob: 3504f53bc75f01c58f6cdfa86bc0a8c4b71c6922 [file] [log] [blame]
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
#include <reference/RefTensorHandle.hpp>
#include <reference/RefTensorHandleFactory.hpp>

#include <doctest/doctest.h>

#include <cstddef>
#include <memory>
Sadik Armagan1625efc2021-06-10 18:24:34 +010010TEST_SUITE("RefTensorHandleTests")
11{
Matthew Bentham7c1603a2019-06-21 17:22:23 +010012using namespace armnn;
13
Sadik Armagan1625efc2021-06-10 18:24:34 +010014TEST_CASE("AcquireAndRelease")
Matthew Bentham7c1603a2019-06-21 17:22:23 +010015{
16 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
17
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010018 TensorInfo info({ 1, 1, 1, 1 }, DataType::Float32);
Matthew Bentham7c1603a2019-06-21 17:22:23 +010019 RefTensorHandle handle(info, memoryManager);
20
21 handle.Manage();
22 handle.Allocate();
23
24 memoryManager->Acquire();
25 {
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010026 float* buffer = reinterpret_cast<float*>(handle.Map());
Matthew Bentham7c1603a2019-06-21 17:22:23 +010027
Sadik Armagan1625efc2021-06-10 18:24:34 +010028 CHECK(buffer != nullptr); // Yields a valid pointer
Matthew Bentham7c1603a2019-06-21 17:22:23 +010029
30 buffer[0] = 2.5f;
31
Sadik Armagan1625efc2021-06-10 18:24:34 +010032 CHECK(buffer[0] == 2.5f); // Memory is writable and readable
Matthew Bentham7c1603a2019-06-21 17:22:23 +010033
34 }
35 memoryManager->Release();
36
37 memoryManager->Acquire();
38 {
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010039 float* buffer = reinterpret_cast<float*>(handle.Map());
Matthew Bentham7c1603a2019-06-21 17:22:23 +010040
Sadik Armagan1625efc2021-06-10 18:24:34 +010041 CHECK(buffer != nullptr); // Yields a valid pointer
Matthew Bentham7c1603a2019-06-21 17:22:23 +010042
43 buffer[0] = 3.5f;
44
Sadik Armagan1625efc2021-06-10 18:24:34 +010045 CHECK(buffer[0] == 3.5f); // Memory is writable and readable
Matthew Bentham7c1603a2019-06-21 17:22:23 +010046 }
47 memoryManager->Release();
48}
49
Sadik Armagan1625efc2021-06-10 18:24:34 +010050TEST_CASE("RefTensorHandleFactoryMemoryManaged")
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010051{
52 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
53 RefTensorHandleFactory handleFactory(memoryManager);
54 TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
55
56 // create TensorHandle with memory managed
57 auto handle = handleFactory.CreateTensorHandle(info, true);
58 handle->Manage();
59 handle->Allocate();
60
61 memoryManager->Acquire();
62 {
63 float* buffer = reinterpret_cast<float*>(handle->Map());
Sadik Armagan1625efc2021-06-10 18:24:34 +010064 CHECK(buffer != nullptr); // Yields a valid pointer
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010065 buffer[0] = 1.5f;
66 buffer[1] = 2.5f;
Sadik Armagan1625efc2021-06-10 18:24:34 +010067 CHECK(buffer[0] == 1.5f); // Memory is writable and readable
68 CHECK(buffer[1] == 2.5f); // Memory is writable and readable
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010069 }
70 memoryManager->Release();
71
72 memoryManager->Acquire();
73 {
74 float* buffer = reinterpret_cast<float*>(handle->Map());
Sadik Armagan1625efc2021-06-10 18:24:34 +010075 CHECK(buffer != nullptr); // Yields a valid pointer
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010076 buffer[0] = 3.5f;
77 buffer[1] = 4.5f;
Sadik Armagan1625efc2021-06-10 18:24:34 +010078 CHECK(buffer[0] == 3.5f); // Memory is writable and readable
79 CHECK(buffer[1] == 4.5f); // Memory is writable and readable
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010080 }
81 memoryManager->Release();
82
83 float testPtr[2] = { 2.5f, 5.5f };
84 // Cannot import as import is disabled
Sadik Armagan1625efc2021-06-10 18:24:34 +010085 CHECK(!handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010086}
87
Sadik Armagan1625efc2021-06-10 18:24:34 +010088TEST_CASE("RefTensorHandleFactoryImport")
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +010089{
90 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
91 RefTensorHandleFactory handleFactory(memoryManager);
92 TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
93
94 // create TensorHandle without memory managed
95 auto handle = handleFactory.CreateTensorHandle(info, false);
96 handle->Manage();
97 handle->Allocate();
98 memoryManager->Acquire();
99
100 // No buffer allocated when import is enabled
Sadik Armagan1625efc2021-06-10 18:24:34 +0100101 CHECK_THROWS_AS(handle->Map(), armnn::NullPointerException);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100102
103 float testPtr[2] = { 2.5f, 5.5f };
104 // Correctly import
Sadik Armagan1625efc2021-06-10 18:24:34 +0100105 CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100106 float* buffer = reinterpret_cast<float*>(handle->Map());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100107 CHECK(buffer != nullptr); // Yields a valid pointer after import
108 CHECK(buffer == testPtr); // buffer is pointing to testPtr
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100109 // Memory is writable and readable with correct value
Sadik Armagan1625efc2021-06-10 18:24:34 +0100110 CHECK(buffer[0] == 2.5f);
111 CHECK(buffer[1] == 5.5f);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100112 buffer[0] = 3.5f;
113 buffer[1] = 10.0f;
Sadik Armagan1625efc2021-06-10 18:24:34 +0100114 CHECK(buffer[0] == 3.5f);
115 CHECK(buffer[1] == 10.0f);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100116 memoryManager->Release();
117}
118
Sadik Armagan1625efc2021-06-10 18:24:34 +0100119TEST_CASE("RefTensorHandleImport")
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100120{
121 TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
122 RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
123
124 handle.Manage();
125 handle.Allocate();
126
127 // No buffer allocated when import is enabled
Sadik Armagan1625efc2021-06-10 18:24:34 +0100128 CHECK_THROWS_AS(handle.Map(), armnn::NullPointerException);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100129
130 float testPtr[2] = { 2.5f, 5.5f };
131 // Correctly import
Sadik Armagan1625efc2021-06-10 18:24:34 +0100132 CHECK(handle.Import(static_cast<void*>(testPtr), MemorySource::Malloc));
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100133 float* buffer = reinterpret_cast<float*>(handle.Map());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100134 CHECK(buffer != nullptr); // Yields a valid pointer after import
135 CHECK(buffer == testPtr); // buffer is pointing to testPtr
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100136 // Memory is writable and readable with correct value
Sadik Armagan1625efc2021-06-10 18:24:34 +0100137 CHECK(buffer[0] == 2.5f);
138 CHECK(buffer[1] == 5.5f);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100139 buffer[0] = 3.5f;
140 buffer[1] = 10.0f;
Sadik Armagan1625efc2021-06-10 18:24:34 +0100141 CHECK(buffer[0] == 3.5f);
142 CHECK(buffer[1] == 10.0f);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100143}
144
Sadik Armagan1625efc2021-06-10 18:24:34 +0100145TEST_CASE("RefTensorHandleGetCapabilities")
Narumol Prangnawaratd6568772020-07-22 12:46:51 +0100146{
147 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
148 RefTensorHandleFactory handleFactory(memoryManager);
149
150 // Builds up the structure of the network.
151 INetworkPtr net(INetwork::Create());
152 IConnectableLayer* input = net->AddInputLayer(0);
153 IConnectableLayer* output = net->AddOutputLayer(0);
154 input->GetOutputSlot(0).Connect(output->GetInputSlot(0));
155
156 std::vector<Capability> capabilities = handleFactory.GetCapabilities(input,
157 output,
158 CapabilityClass::PaddingRequired);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100159 CHECK(capabilities.empty());
Narumol Prangnawaratd6568772020-07-22 12:46:51 +0100160}
161
Sadik Armagan1625efc2021-06-10 18:24:34 +0100162TEST_CASE("RefTensorHandleSupportsInPlaceComputation")
Sadik Armaganab3bd4d2020-08-25 11:48:00 +0100163{
164 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
165 RefTensorHandleFactory handleFactory(memoryManager);
166
167 // RefTensorHandleFactory does not support InPlaceComputation
168 ARMNN_ASSERT(!(handleFactory.SupportsInPlaceComputation()));
169}
170
Sadik Armagan1625efc2021-06-10 18:24:34 +0100171TEST_CASE("TestManagedConstTensorHandle")
Francis Murtagh4af56162021-04-20 16:37:55 +0100172{
173 // Initialize arguments
174 void* mem = nullptr;
175 TensorInfo info;
176
James Conroy1f58f032021-04-27 17:13:27 +0100177 // Use PassthroughTensor as others are abstract
178 auto passThroughHandle = std::make_shared<PassthroughTensorHandle>(info, mem);
Francis Murtagh4af56162021-04-20 16:37:55 +0100179
180 // Test managed handle is initialized with m_Mapped unset and once Map() called its set
181 ManagedConstTensorHandle managedHandle(passThroughHandle);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100182 CHECK(!managedHandle.IsMapped());
Francis Murtagh4af56162021-04-20 16:37:55 +0100183 managedHandle.Map();
Sadik Armagan1625efc2021-06-10 18:24:34 +0100184 CHECK(managedHandle.IsMapped());
Francis Murtagh4af56162021-04-20 16:37:55 +0100185
186 // Test it can then be unmapped
187 managedHandle.Unmap();
Sadik Armagan1625efc2021-06-10 18:24:34 +0100188 CHECK(!managedHandle.IsMapped());
Francis Murtagh4af56162021-04-20 16:37:55 +0100189
190 // Test member function
Sadik Armagan1625efc2021-06-10 18:24:34 +0100191 CHECK(managedHandle.GetTensorInfo() == info);
Francis Murtagh4af56162021-04-20 16:37:55 +0100192
193 // Test that nullptr tensor handle doesn't get mapped
194 ManagedConstTensorHandle managedHandleNull(nullptr);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100195 CHECK(!managedHandleNull.IsMapped());
196 CHECK_THROWS_AS(managedHandleNull.Map(), armnn::Exception);
197 CHECK(!managedHandleNull.IsMapped());
Francis Murtagh4af56162021-04-20 16:37:55 +0100198
199 // Check Unmap() when m_Mapped already false
200 managedHandleNull.Unmap();
Sadik Armagan1625efc2021-06-10 18:24:34 +0100201 CHECK(!managedHandleNull.IsMapped());
Francis Murtagh4af56162021-04-20 16:37:55 +0100202}
203
Ferran Balaguerc33882d2019-08-21 13:59:13 +0100204#if !defined(__ANDROID__)
// Only run these tests on non-Android platforms
Sadik Armagan1625efc2021-06-10 18:24:34 +0100206TEST_CASE("CheckSourceType")
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100207{
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100208 TensorInfo info({1}, DataType::Float32);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100209 RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100210
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100211 int* testPtr = new int(4);
212
213 // Not supported
Sadik Armagan1625efc2021-06-10 18:24:34 +0100214 CHECK(!handle.Import(static_cast<void *>(testPtr), MemorySource::DmaBuf));
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100215
216 // Not supported
Sadik Armagan1625efc2021-06-10 18:24:34 +0100217 CHECK(!handle.Import(static_cast<void *>(testPtr), MemorySource::DmaBufProtected));
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100218
219 // Supported
Sadik Armagan1625efc2021-06-10 18:24:34 +0100220 CHECK(handle.Import(static_cast<void *>(testPtr), MemorySource::Malloc));
Ferran Balaguer1cd451c2019-08-22 14:09:44 +0100221
222 delete testPtr;
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100223}
224
Sadik Armagan1625efc2021-06-10 18:24:34 +0100225TEST_CASE("ReusePointer")
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100226{
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100227 TensorInfo info({1}, DataType::Float32);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100228 RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100229
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100230 int* testPtr = new int(4);
231
232 handle.Import(static_cast<void *>(testPtr), MemorySource::Malloc);
233
234 // Reusing previously Imported pointer
Sadik Armagan1625efc2021-06-10 18:24:34 +0100235 CHECK(handle.Import(static_cast<void *>(testPtr), MemorySource::Malloc));
Ferran Balaguer1cd451c2019-08-22 14:09:44 +0100236
237 delete testPtr;
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100238}
239
Sadik Armagan1625efc2021-06-10 18:24:34 +0100240TEST_CASE("MisalignedPointer")
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100241{
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100242 TensorInfo info({2}, DataType::Float32);
Narumol Prangnawarat3b90af62020-06-26 11:00:21 +0100243 RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100244
Aron Virginas-Tard9f7c8b2019-09-13 13:37:03 +0100245 // Allocate a 2 int array
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100246 int* testPtr = new int[2];
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100247
Aron Virginas-Tard9f7c8b2019-09-13 13:37:03 +0100248 // Increment pointer by 1 byte
249 void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1);
250
Sadik Armagan1625efc2021-06-10 18:24:34 +0100251 CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc));
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100252
253 delete[] testPtr;
254}
255
Nikhil Raj53e06592022-01-05 16:04:08 +0000256TEST_CASE("CheckCanBeImported")
257{
258 TensorInfo info({1}, DataType::Float32);
259 RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
260
261 int* testPtr = new int(4);
262
263 // Not supported
264 CHECK(!handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::DmaBuf));
265
266 // Supported
267 CHECK(handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::Malloc));
268
269 delete testPtr;
270
271}
272
273TEST_CASE("MisalignedCanBeImported")
274{
275 TensorInfo info({2}, DataType::Float32);
276 RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
277
278 // Allocate a 2 int array
279 int* testPtr = new int[2];
280
281 // Increment pointer by 1 byte
282 void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1);
283
284 CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc));
285
286 delete[] testPtr;
287}
288
Ferran Balaguerc33882d2019-08-21 13:59:13 +0100289#endif
290
Sadik Armagan1625efc2021-06-10 18:24:34 +0100291}