blob: 0c6a9c6e7b2344412b897bdf8d266cade472b501 [file] [log] [blame]
Colm Donelanc74b1752021-03-12 15:58:48 +00001//
2// Copyright © 2021 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <cl/ClImportTensorHandleFactory.hpp>
7
8#include <boost/test/unit_test.hpp>
9
10BOOST_AUTO_TEST_SUITE(ClImportTensorHandleFactoryTests)
11using namespace armnn;
12
13BOOST_AUTO_TEST_CASE(ImportTensorFactoryAskedToCreateManagedTensorThrowsException)
14{
15 // Create the factory to import tensors.
16 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
17 static_cast<MemorySourceFlags>(MemorySource::Malloc));
18 TensorInfo tensorInfo;
19 // This factory is designed to import the memory of tensors. Asking for a handle that requires
20 // a memory manager should result in an exception.
21 BOOST_REQUIRE_THROW(factory.CreateTensorHandle(tensorInfo, true), InvalidArgumentException);
22 BOOST_REQUIRE_THROW(factory.CreateTensorHandle(tensorInfo, DataLayout::NCHW, true), InvalidArgumentException);
23}
24
25BOOST_AUTO_TEST_CASE(ImportTensorFactoryCreateMallocTensorHandle)
26{
27 // Create the factory to import tensors.
28 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
29 static_cast<MemorySourceFlags>(MemorySource::Malloc));
30 TensorShape tensorShape{ 6, 7, 8, 9 };
31 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
32 // Start with the TensorInfo factory method. Create an import tensor handle and verify the data is
33 // passed through correctly.
34 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
35 BOOST_ASSERT(tensorHandle);
36 BOOST_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
37 BOOST_ASSERT(tensorHandle->GetShape() == tensorShape);
38
39 // Same method but explicitly specifying isManaged = false.
40 tensorHandle = factory.CreateTensorHandle(tensorInfo, false);
41 BOOST_CHECK(tensorHandle);
42 BOOST_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
43 BOOST_ASSERT(tensorHandle->GetShape() == tensorShape);
44
45 // Now try TensorInfo and DataLayout factory method.
46 tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC);
47 BOOST_CHECK(tensorHandle);
48 BOOST_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
49 BOOST_ASSERT(tensorHandle->GetShape() == tensorShape);
50}
51
52BOOST_AUTO_TEST_CASE(CreateSubtensorOfImportTensor)
53{
54 // Create the factory to import tensors.
55 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
56 static_cast<MemorySourceFlags>(MemorySource::Malloc));
57 // Create a standard inport tensor.
58 TensorShape tensorShape{ 224, 224, 1, 1 };
59 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
60 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
61 // Use the factory to create a 16x16 sub tensor.
62 TensorShape subTensorShape{ 16, 16, 1, 1 };
63 // Starting at an offset of 1x1.
64 uint32_t origin[4] = { 1, 1, 0, 0 };
65 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
66 BOOST_CHECK(subTensor);
67 BOOST_ASSERT(subTensor->GetShape() == subTensorShape);
68 BOOST_ASSERT(subTensor->GetParent() == tensorHandle.get());
69}
70
71BOOST_AUTO_TEST_CASE(CreateSubtensorNonZeroXYIsInvalid)
72{
73 // Create the factory to import tensors.
74 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
75 static_cast<MemorySourceFlags>(MemorySource::Malloc));
76 // Create a standard import tensor.
77 TensorShape tensorShape{ 224, 224, 1, 1 };
78 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
79 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
80 // Use the factory to create a 16x16 sub tensor.
81 TensorShape subTensorShape{ 16, 16, 1, 1 };
82 // This looks a bit backwards because of how Cl specifies tensors. Essentially we want to trigger our
83 // check "(coords.x() != 0 || coords.y() != 0)"
84 uint32_t origin[4] = { 0, 0, 1, 1 };
85 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
86 // We expect a nullptr.
87 BOOST_ASSERT(subTensor == nullptr);
88}
89
90BOOST_AUTO_TEST_CASE(CreateSubtensorXYMustMatchParent)
91{
92 // Create the factory to import tensors.
93 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
94 static_cast<MemorySourceFlags>(MemorySource::Malloc));
95 // Create a standard import tensor.
96 TensorShape tensorShape{ 224, 224, 1, 1 };
97 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
98 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
99 // Use the factory to create a 16x16 sub tensor but make the CL x and y axis different.
100 TensorShape subTensorShape{ 16, 16, 2, 2 };
101 // We want to trigger our ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
102 uint32_t origin[4] = { 1, 1, 0, 0 };
103 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
104 // We expect a nullptr.
105 BOOST_ASSERT(subTensor == nullptr);
106}
107
108BOOST_AUTO_TEST_CASE(CreateSubtensorMustBeSmallerThanParent)
109{
110 // Create the factory to import tensors.
111 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
112 static_cast<MemorySourceFlags>(MemorySource::Malloc));
113 // Create a standard import tensor.
114 TensorShape tensorShape{ 224, 224, 1, 1 };
115 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
116 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
117 // Ask for a subtensor that's the same size as the parent.
118 TensorShape subTensorShape{ 224, 224, 1, 1 };
119 uint32_t origin[4] = { 1, 1, 0, 0 };
120 // This should result in a nullptr.
121 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
122 BOOST_ASSERT(subTensor == nullptr);
123}
124
125BOOST_AUTO_TEST_SUITE_END()