blob: fee40fd2576f409397ac6ff565c079934e98c245 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Sadik Armagan1625efc2021-06-10 18:24:34 +01006#include <armnn/utility/Assert.hpp>
7
Colm Donelanc74b1752021-03-12 15:58:48 +00008#include <cl/ClImportTensorHandleFactory.hpp>
9
Sadik Armagan1625efc2021-06-10 18:24:34 +010010#include <doctest/doctest.h>
Colm Donelanc74b1752021-03-12 15:58:48 +000011
Sadik Armagan1625efc2021-06-10 18:24:34 +010012TEST_SUITE("ClImportTensorHandleFactoryTests")
13{
Colm Donelanc74b1752021-03-12 15:58:48 +000014using namespace armnn;
15
Sadik Armagan1625efc2021-06-10 18:24:34 +010016TEST_CASE("ImportTensorFactoryAskedToCreateManagedTensorThrowsException")
Colm Donelanc74b1752021-03-12 15:58:48 +000017{
18 // Create the factory to import tensors.
19 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
20 static_cast<MemorySourceFlags>(MemorySource::Malloc));
21 TensorInfo tensorInfo;
22 // This factory is designed to import the memory of tensors. Asking for a handle that requires
23 // a memory manager should result in an exception.
Sadik Armagan1625efc2021-06-10 18:24:34 +010024 REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, true), InvalidArgumentException);
25 REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, DataLayout::NCHW, true), InvalidArgumentException);
Colm Donelanc74b1752021-03-12 15:58:48 +000026}
27
Sadik Armagan1625efc2021-06-10 18:24:34 +010028TEST_CASE("ImportTensorFactoryCreateMallocTensorHandle")
Colm Donelanc74b1752021-03-12 15:58:48 +000029{
30 // Create the factory to import tensors.
31 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
32 static_cast<MemorySourceFlags>(MemorySource::Malloc));
33 TensorShape tensorShape{ 6, 7, 8, 9 };
34 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
35 // Start with the TensorInfo factory method. Create an import tensor handle and verify the data is
36 // passed through correctly.
37 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010038 ARMNN_ASSERT(tensorHandle);
39 ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
40 ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape);
Colm Donelanc74b1752021-03-12 15:58:48 +000041
42 // Same method but explicitly specifying isManaged = false.
43 tensorHandle = factory.CreateTensorHandle(tensorInfo, false);
Sadik Armagan1625efc2021-06-10 18:24:34 +010044 CHECK(tensorHandle);
45 ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
46 ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape);
Colm Donelanc74b1752021-03-12 15:58:48 +000047
48 // Now try TensorInfo and DataLayout factory method.
49 tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC);
Sadik Armagan1625efc2021-06-10 18:24:34 +010050 CHECK(tensorHandle);
51 ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
52 ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape);
Colm Donelanc74b1752021-03-12 15:58:48 +000053}
54
Sadik Armagan1625efc2021-06-10 18:24:34 +010055TEST_CASE("CreateSubtensorOfImportTensor")
Colm Donelanc74b1752021-03-12 15:58:48 +000056{
57 // Create the factory to import tensors.
58 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
59 static_cast<MemorySourceFlags>(MemorySource::Malloc));
60 // Create a standard inport tensor.
61 TensorShape tensorShape{ 224, 224, 1, 1 };
62 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
63 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
64 // Use the factory to create a 16x16 sub tensor.
65 TensorShape subTensorShape{ 16, 16, 1, 1 };
66 // Starting at an offset of 1x1.
67 uint32_t origin[4] = { 1, 1, 0, 0 };
68 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
Sadik Armagan1625efc2021-06-10 18:24:34 +010069 CHECK(subTensor);
70 ARMNN_ASSERT(subTensor->GetShape() == subTensorShape);
71 ARMNN_ASSERT(subTensor->GetParent() == tensorHandle.get());
Colm Donelanc74b1752021-03-12 15:58:48 +000072}
73
Sadik Armagan1625efc2021-06-10 18:24:34 +010074TEST_CASE("CreateSubtensorNonZeroXYIsInvalid")
Colm Donelanc74b1752021-03-12 15:58:48 +000075{
76 // Create the factory to import tensors.
77 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
78 static_cast<MemorySourceFlags>(MemorySource::Malloc));
79 // Create a standard import tensor.
80 TensorShape tensorShape{ 224, 224, 1, 1 };
81 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
82 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
83 // Use the factory to create a 16x16 sub tensor.
84 TensorShape subTensorShape{ 16, 16, 1, 1 };
85 // This looks a bit backwards because of how Cl specifies tensors. Essentially we want to trigger our
86 // check "(coords.x() != 0 || coords.y() != 0)"
87 uint32_t origin[4] = { 0, 0, 1, 1 };
88 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
89 // We expect a nullptr.
Sadik Armagan1625efc2021-06-10 18:24:34 +010090 ARMNN_ASSERT(subTensor == nullptr);
Colm Donelanc74b1752021-03-12 15:58:48 +000091}
92
Sadik Armagan1625efc2021-06-10 18:24:34 +010093TEST_CASE("CreateSubtensorXYMustMatchParent")
Colm Donelanc74b1752021-03-12 15:58:48 +000094{
95 // Create the factory to import tensors.
96 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
97 static_cast<MemorySourceFlags>(MemorySource::Malloc));
98 // Create a standard import tensor.
99 TensorShape tensorShape{ 224, 224, 1, 1 };
100 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
101 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
102 // Use the factory to create a 16x16 sub tensor but make the CL x and y axis different.
103 TensorShape subTensorShape{ 16, 16, 2, 2 };
104 // We want to trigger our ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
105 uint32_t origin[4] = { 1, 1, 0, 0 };
106 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
107 // We expect a nullptr.
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108 ARMNN_ASSERT(subTensor == nullptr);
Colm Donelanc74b1752021-03-12 15:58:48 +0000109}
110
Sadik Armagan1625efc2021-06-10 18:24:34 +0100111TEST_CASE("CreateSubtensorMustBeSmallerThanParent")
Colm Donelanc74b1752021-03-12 15:58:48 +0000112{
113 // Create the factory to import tensors.
114 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
115 static_cast<MemorySourceFlags>(MemorySource::Malloc));
116 // Create a standard import tensor.
117 TensorShape tensorShape{ 224, 224, 1, 1 };
118 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
119 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
120 // Ask for a subtensor that's the same size as the parent.
121 TensorShape subTensorShape{ 224, 224, 1, 1 };
122 uint32_t origin[4] = { 1, 1, 0, 0 };
123 // This should result in a nullptr.
124 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100125 ARMNN_ASSERT(subTensor == nullptr);
Colm Donelanc74b1752021-03-12 15:58:48 +0000126}
127
Sadik Armagan1625efc2021-06-10 18:24:34 +0100128}