blob: bfb74af801c0d894628b593f99ec8fb91115f6f5 [file] [log] [blame]
David Monahane4a41dc2021-04-14 16:55:36 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include <arm_compute/runtime/CL/functions/CLActivationLayer.h>

#include <cl/ClImportTensorHandle.hpp>
#include <cl/ClImportTensorHandleFactory.hpp>
#include <cl/test/ClContextControlFixture.hpp>

#include <boost/test/unit_test.hpp>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
13
14using namespace armnn;
15
16BOOST_AUTO_TEST_SUITE(ClImportTensorHandleTests)
17
18BOOST_FIXTURE_TEST_CASE(ClMallocImport, ClContextControlFixture)
19{
20 ClImportTensorHandleFactory handleFactory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
21 static_cast<MemorySourceFlags>(MemorySource::Malloc));
22
23 TensorInfo info({ 1, 24, 16, 3 }, DataType::Float32);
24 unsigned int numElements = info.GetNumElements();
25
26 // create TensorHandle for memory import
27 auto handle = handleFactory.CreateTensorHandle(info);
28
29 // Get CLtensor
30 arm_compute::CLTensor& tensor = PolymorphicDowncast<ClImportTensorHandle*>(handle.get())->GetTensor();
31
32 // Create and configure activation function
33 const arm_compute::ActivationLayerInfo act_info(arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
34 arm_compute::CLActivationLayer act_func;
35 act_func.configure(&tensor, nullptr, act_info);
36
37 // Allocate user memory
38 const size_t totalBytes = tensor.info()->total_size();
39 const size_t alignment =
40 arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
41 size_t space = totalBytes + alignment;
42 auto testData = std::make_unique<uint8_t[]>(space);
43 void* alignedPtr = testData.get();
44 BOOST_CHECK(std::align(alignment, totalBytes, alignedPtr, space));
45
46 // Import memory
47 BOOST_CHECK(handle->Import(alignedPtr, armnn::MemorySource::Malloc));
48
49 // Input with negative values
50 auto* typedPtr = reinterpret_cast<float*>(alignedPtr);
51 std::fill_n(typedPtr, numElements, -5.0f);
52
53 // Execute function and sync
54 act_func.run();
55 arm_compute::CLScheduler::get().sync();
56
57 // Validate result by checking that the output has no negative values
58 for(unsigned int i = 0; i < numElements; ++i)
59 {
60 BOOST_ASSERT(typedPtr[i] >= 0);
61 }
62}
63
64BOOST_FIXTURE_TEST_CASE(ClIncorrectMemorySourceImport, ClContextControlFixture)
65{
66 ClImportTensorHandleFactory handleFactory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
67 static_cast<MemorySourceFlags>(MemorySource::Malloc));
68
69 TensorInfo info({ 1, 24, 16, 3 }, DataType::Float32);
70
71 // create TensorHandle for memory import
72 auto handle = handleFactory.CreateTensorHandle(info);
73
74 // Get CLtensor
75 arm_compute::CLTensor& tensor = PolymorphicDowncast<ClImportTensorHandle*>(handle.get())->GetTensor();
76
77 // Allocate user memory
78 const size_t totalBytes = tensor.info()->total_size();
79 const size_t alignment =
80 arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
81 size_t space = totalBytes + alignment;
82 auto testData = std::make_unique<uint8_t[]>(space);
83 void* alignedPtr = testData.get();
84 BOOST_CHECK(std::align(alignment, totalBytes, alignedPtr, space));
85
86 // Import memory
87 BOOST_CHECK_THROW(handle->Import(alignedPtr, armnn::MemorySource::Undefined), MemoryImportException);
88}
89
90BOOST_FIXTURE_TEST_CASE(ClInvalidMemorySourceImport, ClContextControlFixture)
91{
92 MemorySource invalidMemSource = static_cast<MemorySource>(256);
93 ClImportTensorHandleFactory handleFactory(static_cast<MemorySourceFlags>(invalidMemSource),
94 static_cast<MemorySourceFlags>(invalidMemSource));
95
96 TensorInfo info({ 1, 2, 2, 1 }, DataType::Float32);
97
98 // create TensorHandle for memory import
99 auto handle = handleFactory.CreateTensorHandle(info);
100
101 // Allocate user memory
102 std::vector<float> inputData
103 {
104 1.0f, 2.0f, 3.0f, 4.0f
105 };
106
107 // Import non-support memory
108 BOOST_CHECK_THROW(handle->Import(inputData.data(), invalidMemSource), MemoryImportException);
109}
110
111BOOST_AUTO_TEST_SUITE_END()