blob: 46be3a122d216e169ff4f01828e687280bbab676 [file] [log] [blame]
Colm Donelanc74b1752021-03-12 15:58:48 +00001//
Colm Donelan68c60e92024-02-21 15:58:35 +00002// Copyright © 2021, 2024 Arm Ltd. All rights reserved.
Colm Donelanc74b1752021-03-12 15:58:48 +00003// SPDX-License-Identifier: MIT
4//
5
6#include <cl/ClImportTensorHandleFactory.hpp>
7
Sadik Armagan1625efc2021-06-10 18:24:34 +01008#include <doctest/doctest.h>
Colm Donelanc74b1752021-03-12 15:58:48 +00009
Sadik Armagan1625efc2021-06-10 18:24:34 +010010TEST_SUITE("ClImportTensorHandleFactoryTests")
11{
Colm Donelanc74b1752021-03-12 15:58:48 +000012using namespace armnn;
13
Sadik Armagan1625efc2021-06-10 18:24:34 +010014TEST_CASE("ImportTensorFactoryAskedToCreateManagedTensorThrowsException")
Colm Donelanc74b1752021-03-12 15:58:48 +000015{
16 // Create the factory to import tensors.
17 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
18 static_cast<MemorySourceFlags>(MemorySource::Malloc));
19 TensorInfo tensorInfo;
20 // This factory is designed to import the memory of tensors. Asking for a handle that requires
21 // a memory manager should result in an exception.
Sadik Armagan1625efc2021-06-10 18:24:34 +010022 REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, true), InvalidArgumentException);
23 REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, DataLayout::NCHW, true), InvalidArgumentException);
Colm Donelanc74b1752021-03-12 15:58:48 +000024}
25
Sadik Armagan1625efc2021-06-10 18:24:34 +010026TEST_CASE("ImportTensorFactoryCreateMallocTensorHandle")
Colm Donelanc74b1752021-03-12 15:58:48 +000027{
28 // Create the factory to import tensors.
29 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
30 static_cast<MemorySourceFlags>(MemorySource::Malloc));
31 TensorShape tensorShape{ 6, 7, 8, 9 };
32 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
33 // Start with the TensorInfo factory method. Create an import tensor handle and verify the data is
34 // passed through correctly.
35 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
Colm Donelan68c60e92024-02-21 15:58:35 +000036 CHECK(tensorHandle);
37 CHECK(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
38 CHECK(tensorHandle->GetShape() == tensorShape);
Colm Donelanc74b1752021-03-12 15:58:48 +000039
40 // Same method but explicitly specifying isManaged = false.
41 tensorHandle = factory.CreateTensorHandle(tensorInfo, false);
Sadik Armagan1625efc2021-06-10 18:24:34 +010042 CHECK(tensorHandle);
Colm Donelan68c60e92024-02-21 15:58:35 +000043 CHECK(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
44 CHECK(tensorHandle->GetShape() == tensorShape);
Colm Donelanc74b1752021-03-12 15:58:48 +000045
46 // Now try TensorInfo and DataLayout factory method.
47 tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC);
Sadik Armagan1625efc2021-06-10 18:24:34 +010048 CHECK(tensorHandle);
Colm Donelan68c60e92024-02-21 15:58:35 +000049 CHECK(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
50 CHECK(tensorHandle->GetShape() == tensorShape);
Colm Donelanc74b1752021-03-12 15:58:48 +000051}
52
Sadik Armagan1625efc2021-06-10 18:24:34 +010053TEST_CASE("CreateSubtensorOfImportTensor")
Colm Donelanc74b1752021-03-12 15:58:48 +000054{
55 // Create the factory to import tensors.
56 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
57 static_cast<MemorySourceFlags>(MemorySource::Malloc));
58 // Create a standard inport tensor.
59 TensorShape tensorShape{ 224, 224, 1, 1 };
60 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
61 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
62 // Use the factory to create a 16x16 sub tensor.
63 TensorShape subTensorShape{ 16, 16, 1, 1 };
64 // Starting at an offset of 1x1.
65 uint32_t origin[4] = { 1, 1, 0, 0 };
66 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
Sadik Armagan1625efc2021-06-10 18:24:34 +010067 CHECK(subTensor);
Colm Donelan68c60e92024-02-21 15:58:35 +000068 CHECK(subTensor->GetShape() == subTensorShape);
69 CHECK(subTensor->GetParent() == tensorHandle.get());
Colm Donelanc74b1752021-03-12 15:58:48 +000070}
71
Sadik Armagan1625efc2021-06-10 18:24:34 +010072TEST_CASE("CreateSubtensorNonZeroXYIsInvalid")
Colm Donelanc74b1752021-03-12 15:58:48 +000073{
74 // Create the factory to import tensors.
75 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
76 static_cast<MemorySourceFlags>(MemorySource::Malloc));
77 // Create a standard import tensor.
78 TensorShape tensorShape{ 224, 224, 1, 1 };
79 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
80 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
81 // Use the factory to create a 16x16 sub tensor.
82 TensorShape subTensorShape{ 16, 16, 1, 1 };
83 // This looks a bit backwards because of how Cl specifies tensors. Essentially we want to trigger our
84 // check "(coords.x() != 0 || coords.y() != 0)"
85 uint32_t origin[4] = { 0, 0, 1, 1 };
86 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
87 // We expect a nullptr.
Colm Donelan68c60e92024-02-21 15:58:35 +000088 CHECK(subTensor == nullptr);
Colm Donelanc74b1752021-03-12 15:58:48 +000089}
90
Sadik Armagan1625efc2021-06-10 18:24:34 +010091TEST_CASE("CreateSubtensorXYMustMatchParent")
Colm Donelanc74b1752021-03-12 15:58:48 +000092{
93 // Create the factory to import tensors.
94 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
95 static_cast<MemorySourceFlags>(MemorySource::Malloc));
96 // Create a standard import tensor.
97 TensorShape tensorShape{ 224, 224, 1, 1 };
98 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
99 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
100 // Use the factory to create a 16x16 sub tensor but make the CL x and y axis different.
101 TensorShape subTensorShape{ 16, 16, 2, 2 };
102 // We want to trigger our ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
103 uint32_t origin[4] = { 1, 1, 0, 0 };
104 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
105 // We expect a nullptr.
Colm Donelan68c60e92024-02-21 15:58:35 +0000106 CHECK(subTensor == nullptr);
Colm Donelanc74b1752021-03-12 15:58:48 +0000107}
108
Sadik Armagan1625efc2021-06-10 18:24:34 +0100109TEST_CASE("CreateSubtensorMustBeSmallerThanParent")
Colm Donelanc74b1752021-03-12 15:58:48 +0000110{
111 // Create the factory to import tensors.
112 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
113 static_cast<MemorySourceFlags>(MemorySource::Malloc));
114 // Create a standard import tensor.
115 TensorShape tensorShape{ 224, 224, 1, 1 };
116 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
117 auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
118 // Ask for a subtensor that's the same size as the parent.
119 TensorShape subTensorShape{ 224, 224, 1, 1 };
120 uint32_t origin[4] = { 1, 1, 0, 0 };
121 // This should result in a nullptr.
122 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
Colm Donelan68c60e92024-02-21 15:58:35 +0000123 CHECK(subTensor == nullptr);
Colm Donelanc74b1752021-03-12 15:58:48 +0000124}
125
Sadik Armagan1625efc2021-06-10 18:24:34 +0100126}