blob: af09e4ed5645de9151a1fb54129291d10ef56662 [file] [log] [blame]
Francis Murtaghca49a242021-09-28 15:30:31 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <backendsCommon/memoryOptimizationStrategies/MemoryOptimizerStrategyValidator.hpp>
7
8#include <doctest/doctest.h>
9#include <vector>
10
11using namespace armnn;
12
13TEST_SUITE("MemoryOptimizerStrategyValidatorTestSuite")
14{
15
16// TestMemoryOptimizerStrategy: Create a MemBin and put all blocks in it so the can overlap.
17class TestMemoryOptimizerStrategy : public IMemoryOptimizerStrategy
18{
19public:
20 TestMemoryOptimizerStrategy(MemBlockStrategyType type)
21 : m_Name(std::string("testMemoryOptimizerStrategy"))
22 , m_MemBlockStrategyType(type) {}
23
24 std::string GetName() const override
25 {
26 return m_Name;
27 }
28
29 MemBlockStrategyType GetMemBlockStrategyType() const override
30 {
31 return m_MemBlockStrategyType;
32 }
33
34 std::vector<MemBin> Optimize(std::vector<MemBlock>& memBlocks) override
35 {
36 std::vector<MemBin> memBins;
37 memBins.reserve(memBlocks.size());
38
39 MemBin memBin;
40 memBin.m_MemBlocks.reserve(memBlocks.size());
41 memBin.m_MemSize = 0;
42 for (auto& memBlock : memBlocks)
43 {
44
45 memBin.m_MemSize = memBin.m_MemSize + memBlock.m_MemSize;
46 memBin.m_MemBlocks.push_back(memBlock);
47 }
48 memBins.push_back(memBin);
49
50 return memBins;
51 }
52
53private:
54 std::string m_Name;
55 MemBlockStrategyType m_MemBlockStrategyType;
56};
57
58TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapX")
59{
60 // create a few memory blocks
61 MemBlock memBlock0(0, 5, 20, 0, 0);
62 MemBlock memBlock1(5, 10, 10, 0, 1);
63 MemBlock memBlock2(10, 15, 15, 0, 2);
64 MemBlock memBlock3(15, 20, 20, 0, 3);
65 MemBlock memBlock4(20, 25, 5, 0, 4);
66
67 std::vector<MemBlock> memBlocks;
68 memBlocks.reserve(5);
69 memBlocks.push_back(memBlock0);
70 memBlocks.push_back(memBlock1);
71 memBlocks.push_back(memBlock2);
72 memBlocks.push_back(memBlock3);
73 memBlocks.push_back(memBlock4);
74
75 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
76 TestMemoryOptimizerStrategy testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
77 auto ptr = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategySingle);
78 MemoryOptimizerValidator validator(std::move(ptr));
79 // SingleAxisPacking can overlap on X axis.
80 CHECK(validator.Validate(memBlocks));
81
82 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
83 TestMemoryOptimizerStrategy testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
84 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategyMulti);
85 MemoryOptimizerValidator validatorMulti(std::move(ptrMulti));
86 // MultiAxisPacking can overlap on X axis.
87 CHECK(validatorMulti.Validate(memBlocks));
88}
89
90TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapXAndY")
91{
92 // create a few memory blocks
93 MemBlock memBlock0(0, 5, 20, 0, 0);
94 MemBlock memBlock1(0, 10, 10, 0, 1);
95 MemBlock memBlock2(0, 15, 15, 0, 2);
96 MemBlock memBlock3(0, 20, 20, 0, 3);
97 MemBlock memBlock4(0, 25, 5, 0, 4);
98
99 std::vector<MemBlock> memBlocks;
100 memBlocks.reserve(5);
101 memBlocks.push_back(memBlock0);
102 memBlocks.push_back(memBlock1);
103 memBlocks.push_back(memBlock2);
104 memBlocks.push_back(memBlock3);
105 memBlocks.push_back(memBlock4);
106
107 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
108 TestMemoryOptimizerStrategy testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
109 auto ptr = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategySingle);
110 MemoryOptimizerValidator validator(std::move(ptr));
111 // SingleAxisPacking cannot overlap on both X and Y axis.
112 CHECK(!validator.Validate(memBlocks));
113
114 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
115 TestMemoryOptimizerStrategy testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
116 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategyMulti);
117 MemoryOptimizerValidator validatorMulti(std::move(ptrMulti));
118 // MultiAxisPacking cannot overlap on both X and Y axis.
119 CHECK(!validatorMulti.Validate(memBlocks));
120}
121
122TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapY")
123{
124 // create a few memory blocks
125 MemBlock memBlock0(0, 2, 20, 0, 0);
126 MemBlock memBlock1(0, 3, 10, 20, 1);
127 MemBlock memBlock2(0, 5, 15, 30, 2);
128 MemBlock memBlock3(0, 6, 20, 50, 3);
129 MemBlock memBlock4(0, 8, 5, 70, 4);
130
131 std::vector<MemBlock> memBlocks;
132 memBlocks.reserve(5);
133 memBlocks.push_back(memBlock0);
134 memBlocks.push_back(memBlock1);
135 memBlocks.push_back(memBlock2);
136 memBlocks.push_back(memBlock3);
137 memBlocks.push_back(memBlock4);
138
139 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
140 TestMemoryOptimizerStrategy testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
141 auto ptr = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategySingle);
142 MemoryOptimizerValidator validator(std::move(ptr));
143 // SingleAxisPacking cannot overlap on Y axis
144 CHECK(!validator.Validate(memBlocks));
145
146 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
147 TestMemoryOptimizerStrategy testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
148 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategyMulti);
149 MemoryOptimizerValidator validatorMulti(std::move(ptrMulti));
150 // MultiAxisPacking can overlap on Y axis
151 CHECK(validatorMulti.Validate(memBlocks));
152}
153
154// TestMemoryOptimizerStrategyDuplicate: Create a MemBin and put all blocks in it duplicating each so validator
155// can check
156class TestMemoryOptimizerStrategyDuplicate : public TestMemoryOptimizerStrategy
157{
158public:
159 TestMemoryOptimizerStrategyDuplicate(MemBlockStrategyType type)
160 : TestMemoryOptimizerStrategy(type)
161 {}
162
163 std::vector<MemBin> Optimize(std::vector<MemBlock>& memBlocks) override
164 {
165 std::vector<MemBin> memBins;
166 memBins.reserve(memBlocks.size());
167
168 MemBin memBin;
169 memBin.m_MemBlocks.reserve(memBlocks.size());
170 for (auto& memBlock : memBlocks)
171 {
172 memBin.m_MemSize = memBin.m_MemSize + memBlock.m_MemSize;
173 memBin.m_MemBlocks.push_back(memBlock);
174 // Put block in twice so it gets found twice
175 memBin.m_MemBlocks.push_back(memBlock);
176 }
177 memBins.push_back(memBin);
178
179 return memBins;
180 }
181};
182
183TEST_CASE("MemoryOptimizerStrategyValidatorTestDuplicateBlocks")
184{
185 // create a few memory blocks
186 MemBlock memBlock0(0, 2, 20, 0, 0);
187 MemBlock memBlock1(2, 3, 10, 20, 1);
188 MemBlock memBlock2(3, 5, 15, 30, 2);
189 MemBlock memBlock3(5, 6, 20, 50, 3);
190 MemBlock memBlock4(7, 8, 5, 70, 4);
191
192 std::vector<MemBlock> memBlocks;
193 memBlocks.reserve(5);
194 memBlocks.push_back(memBlock0);
195 memBlocks.push_back(memBlock1);
196 memBlocks.push_back(memBlock2);
197 memBlocks.push_back(memBlock3);
198 memBlocks.push_back(memBlock4);
199
200 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
201 // Duplicate strategy is invalid as same block is found twice
202 TestMemoryOptimizerStrategyDuplicate testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
203 auto ptr = std::make_shared<TestMemoryOptimizerStrategyDuplicate>(testMemoryOptimizerStrategySingle);
204 MemoryOptimizerValidator validator(std::move(ptr));
205 CHECK(!validator.Validate(memBlocks));
206
207 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
208 TestMemoryOptimizerStrategyDuplicate testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
209 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategyDuplicate>(testMemoryOptimizerStrategyMulti);
210 MemoryOptimizerValidator validatorMulti(std::move(ptrMulti));
211 CHECK(!validatorMulti.Validate(memBlocks));
212}
213
214// TestMemoryOptimizerStrategySkip: Create a MemBin and put all blocks in it skipping every other block so validator
215// can check
216class TestMemoryOptimizerStrategySkip : public TestMemoryOptimizerStrategy
217{
218public:
219 TestMemoryOptimizerStrategySkip(MemBlockStrategyType type)
220 : TestMemoryOptimizerStrategy(type)
221 {}
222
223 std::vector<MemBin> Optimize(std::vector<MemBlock>& memBlocks) override
224 {
225 std::vector<MemBin> memBins;
226 memBins.reserve(memBlocks.size());
227
228 MemBin memBin;
229 memBin.m_MemBlocks.reserve(memBlocks.size());
230 for (unsigned int i = 0; i < memBlocks.size()-1; i+=2)
231 {
232 auto memBlock = memBlocks[i];
233 memBin.m_MemSize = memBin.m_MemSize + memBlock.m_MemSize;
234 memBin.m_MemBlocks.push_back(memBlock);
235 }
236 memBins.push_back(memBin);
237
238 return memBins;
239 }
240};
241
242TEST_CASE("MemoryOptimizerStrategyValidatorTestSkipBlocks")
243{
244 // create a few memory blocks
245 MemBlock memBlock0(0, 2, 20, 0, 0);
246 MemBlock memBlock1(2, 3, 10, 20, 1);
247 MemBlock memBlock2(3, 5, 15, 30, 2);
248 MemBlock memBlock3(5, 6, 20, 50, 3);
249 MemBlock memBlock4(7, 8, 5, 70, 4);
250
251 std::vector<MemBlock> memBlocks;
252 memBlocks.reserve(5);
253 memBlocks.push_back(memBlock0);
254 memBlocks.push_back(memBlock1);
255 memBlocks.push_back(memBlock2);
256 memBlocks.push_back(memBlock3);
257 memBlocks.push_back(memBlock4);
258
259 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
260 // Skip strategy is invalid as every second block is not found
261 TestMemoryOptimizerStrategySkip testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
262 auto ptr = std::make_shared<TestMemoryOptimizerStrategySkip>(testMemoryOptimizerStrategySingle);
263 MemoryOptimizerValidator validator(std::move(ptr));
264 CHECK(!validator.Validate(memBlocks));
265
266 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
267 TestMemoryOptimizerStrategySkip testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
268 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategySkip>(testMemoryOptimizerStrategyMulti);
269 MemoryOptimizerValidator validatorMulti(std::move(ptrMulti));
270 CHECK(!validatorMulti.Validate(memBlocks));
271}
272
273}