blob: 196c0fb41281d58afedb2cf02abd96f5d011392f [file] [log] [blame]
Francis Murtaghe8d7ccb2021-10-14 17:30:24 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <armnn/backends/ICustomAllocator.hpp>
7#include <armnn/Descriptors.hpp>
8#include <armnn/Exceptions.hpp>
9#include <armnn/IRuntime.hpp>
10#include <backendsCommon/TensorHandle.hpp>
11// Requires the OpenCl backend to be included (GpuAcc)
12#include <cl/ClBackend.hpp>
13#include <doctest/doctest.h>
14#include <backendsCommon/DefaultAllocator.hpp>
15#include <backendsCommon/test/MockBackend.hpp>
16
17using namespace armnn;
18
19
20namespace
21{
22
23TEST_SUITE("DefaultAllocatorTests")
24{
25
26TEST_CASE("DefaultAllocatorTest")
27{
28 float number = 3;
29
30 TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
31
32 // Create ArmNN runtime
33 IRuntime::CreationOptions options; // default options
34 auto customAllocator = std::make_shared<DefaultAllocator>();
35 options.m_CustomAllocatorMap = {{"GpuAcc", std::move(customAllocator)}};
36 IRuntimePtr run = IRuntime::Create(options);
37
38 // Creates structures for input & output
39 unsigned int numElements = inputTensorInfo.GetNumElements();
40 size_t totalBytes = numElements * sizeof(float);
41
42 void* alignedInputPtr = options.m_CustomAllocatorMap["GpuAcc"]->allocate(totalBytes, 0);
43
44 auto* inputPtr = reinterpret_cast<float*>(alignedInputPtr);
45 std::fill_n(inputPtr, numElements, number);
46 CHECK(inputPtr[0] == 3);
47
48 auto& backendRegistry = armnn::BackendRegistryInstance();
49 backendRegistry.DeregisterAllocator(ClBackend::GetIdStatic());
50}
51
52TEST_CASE("DefaultAllocatorTestMulti")
53{
54 float number = 3;
55
56 TensorInfo inputTensorInfo(TensorShape({2, 1}), DataType::Float32);
57
58 // Create ArmNN runtime
59 IRuntime::CreationOptions options; // default options
60 auto customAllocator = std::make_shared<DefaultAllocator>();
61 options.m_CustomAllocatorMap = {{"GpuAcc", std::move(customAllocator)}};
62 IRuntimePtr run = IRuntime::Create(options);
63
64 // Creates structures for input & output
65 unsigned int numElements = inputTensorInfo.GetNumElements();
66 size_t totalBytes = numElements * sizeof(float);
67
68 void* alignedInputPtr = options.m_CustomAllocatorMap["GpuAcc"]->allocate(totalBytes, 0);
69 void* alignedInputPtr2 = options.m_CustomAllocatorMap["GpuAcc"]->allocate(totalBytes, 0);
70
71 auto* inputPtr = reinterpret_cast<float*>(alignedInputPtr);
72 std::fill_n(inputPtr, numElements, number);
73 CHECK(inputPtr[0] == 3);
74 CHECK(inputPtr[1] == 3);
75
76 auto* inputPtr2 = reinterpret_cast<float*>(alignedInputPtr2);
77 std::fill_n(inputPtr2, numElements, number);
78 CHECK(inputPtr2[0] == 3);
79 CHECK(inputPtr2[1] == 3);
80
81 // No overlap
82 CHECK(inputPtr[0] == 3);
83 CHECK(inputPtr[1] == 3);
84
85 auto& backendRegistry = armnn::BackendRegistryInstance();
86 backendRegistry.DeregisterAllocator(ClBackend::GetIdStatic());
87}
88
89TEST_CASE("DefaultAllocatorTestMock")
90{
91 // Create ArmNN runtime
92 IRuntime::CreationOptions options; // default options
93 IRuntimePtr run = IRuntime::Create(options);
94
95 // Initialize Mock Backend
96 MockBackendInitialiser initialiser;
97 auto factoryFun = BackendRegistryInstance().GetFactory(MockBackend().GetIdStatic());
98 ARMNN_ASSERT(factoryFun != nullptr);
99 auto backend = factoryFun();
100 auto defaultAllocator = backend->GetDefaultAllocator();
101
102 // GetMemorySourceType
103 CHECK(defaultAllocator->GetMemorySourceType() == MemorySource::Malloc);
104
105 size_t totalBytes = 1 * sizeof(float);
106 // Allocate
107 void* ptr = defaultAllocator->allocate(totalBytes, 0);
108
109 // GetMemoryRegionAtOffset
110 CHECK(defaultAllocator->GetMemoryRegionAtOffset(ptr, 0, 0));
111
112 // Free
113 defaultAllocator->free(ptr);
114
115 // Clean up
116 auto& backendRegistry = armnn::BackendRegistryInstance();
117 backendRegistry.Deregister(MockBackend().GetIdStatic());
118 backendRegistry.DeregisterAllocator(ClBackend::GetIdStatic());
119}
120
121
122}
123
124} // namespace armnn