blob: 2438019878646751ac3aba470dd0ba218f90112a [file] [log] [blame]
Francis Murtaghca49a242021-09-28 15:30:31 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Jim Flynne1fdd282021-10-26 21:26:10 +01006#include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.hpp>
Francis Murtaghca49a242021-09-28 15:30:31 +01007
8#include <doctest/doctest.h>
9#include <vector>
10
11using namespace armnn;
12
13TEST_SUITE("MemoryOptimizerStrategyValidatorTestSuite")
14{
15
16// TestMemoryOptimizerStrategy: Create a MemBin and put all blocks in it so the can overlap.
17class TestMemoryOptimizerStrategy : public IMemoryOptimizerStrategy
18{
19public:
20 TestMemoryOptimizerStrategy(MemBlockStrategyType type)
21 : m_Name(std::string("testMemoryOptimizerStrategy"))
22 , m_MemBlockStrategyType(type) {}
23
24 std::string GetName() const override
25 {
26 return m_Name;
27 }
28
29 MemBlockStrategyType GetMemBlockStrategyType() const override
30 {
31 return m_MemBlockStrategyType;
32 }
33
34 std::vector<MemBin> Optimize(std::vector<MemBlock>& memBlocks) override
35 {
36 std::vector<MemBin> memBins;
37 memBins.reserve(memBlocks.size());
38
39 MemBin memBin;
40 memBin.m_MemBlocks.reserve(memBlocks.size());
41 memBin.m_MemSize = 0;
42 for (auto& memBlock : memBlocks)
43 {
44
45 memBin.m_MemSize = memBin.m_MemSize + memBlock.m_MemSize;
46 memBin.m_MemBlocks.push_back(memBlock);
47 }
48 memBins.push_back(memBin);
49
50 return memBins;
51 }
52
53private:
54 std::string m_Name;
55 MemBlockStrategyType m_MemBlockStrategyType;
56};
57
58TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapX")
59{
60 // create a few memory blocks
61 MemBlock memBlock0(0, 5, 20, 0, 0);
Finn Williamse933c382021-11-10 19:43:51 +000062 MemBlock memBlock1(6, 10, 10, 0, 1);
63 MemBlock memBlock2(11, 15, 15, 0, 2);
64 MemBlock memBlock3(16, 20, 20, 0, 3);
65 MemBlock memBlock4(21, 25, 5, 0, 4);
Francis Murtaghca49a242021-09-28 15:30:31 +010066
67 std::vector<MemBlock> memBlocks;
68 memBlocks.reserve(5);
69 memBlocks.push_back(memBlock0);
70 memBlocks.push_back(memBlock1);
71 memBlocks.push_back(memBlock2);
72 memBlocks.push_back(memBlock3);
73 memBlocks.push_back(memBlock4);
74
75 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
76 TestMemoryOptimizerStrategy testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
77 auto ptr = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategySingle);
Jim Flynne1fdd282021-10-26 21:26:10 +010078 StrategyValidator validator;
79 validator.SetStrategy(ptr);
Francis Murtaghca49a242021-09-28 15:30:31 +010080 // SingleAxisPacking can overlap on X axis.
Jim Flynne1fdd282021-10-26 21:26:10 +010081 CHECK_NOTHROW(validator.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +010082
83 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
84 TestMemoryOptimizerStrategy testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
85 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategyMulti);
Jim Flynne1fdd282021-10-26 21:26:10 +010086 StrategyValidator validatorMulti;
87 validatorMulti.SetStrategy(ptrMulti);
Francis Murtaghca49a242021-09-28 15:30:31 +010088 // MultiAxisPacking can overlap on X axis.
Jim Flynne1fdd282021-10-26 21:26:10 +010089 CHECK_NOTHROW(validatorMulti.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +010090}
91
92TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapXAndY")
93{
94 // create a few memory blocks
95 MemBlock memBlock0(0, 5, 20, 0, 0);
96 MemBlock memBlock1(0, 10, 10, 0, 1);
97 MemBlock memBlock2(0, 15, 15, 0, 2);
98 MemBlock memBlock3(0, 20, 20, 0, 3);
99 MemBlock memBlock4(0, 25, 5, 0, 4);
100
101 std::vector<MemBlock> memBlocks;
102 memBlocks.reserve(5);
103 memBlocks.push_back(memBlock0);
104 memBlocks.push_back(memBlock1);
105 memBlocks.push_back(memBlock2);
106 memBlocks.push_back(memBlock3);
107 memBlocks.push_back(memBlock4);
108
109 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
110 TestMemoryOptimizerStrategy testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
111 auto ptr = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategySingle);
Jim Flynne1fdd282021-10-26 21:26:10 +0100112 StrategyValidator validator;
113 validator.SetStrategy(ptr);
Francis Murtaghca49a242021-09-28 15:30:31 +0100114 // SingleAxisPacking cannot overlap on both X and Y axis.
Jim Flynne1fdd282021-10-26 21:26:10 +0100115 CHECK_THROWS(validator.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100116
117 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
118 TestMemoryOptimizerStrategy testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
119 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategyMulti);
Jim Flynne1fdd282021-10-26 21:26:10 +0100120 StrategyValidator validatorMulti;
121 validatorMulti.SetStrategy(ptrMulti);
Francis Murtaghca49a242021-09-28 15:30:31 +0100122 // MultiAxisPacking cannot overlap on both X and Y axis.
Jim Flynne1fdd282021-10-26 21:26:10 +0100123 CHECK_THROWS(validatorMulti.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100124}
125
126TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapY")
127{
128 // create a few memory blocks
129 MemBlock memBlock0(0, 2, 20, 0, 0);
Finn Williamse933c382021-11-10 19:43:51 +0000130 MemBlock memBlock1(0, 3, 10, 21, 1);
131 MemBlock memBlock2(0, 5, 15, 37, 2);
132 MemBlock memBlock3(0, 6, 20, 58, 3);
133 MemBlock memBlock4(0, 8, 5, 79, 4);
Francis Murtaghca49a242021-09-28 15:30:31 +0100134
135 std::vector<MemBlock> memBlocks;
136 memBlocks.reserve(5);
137 memBlocks.push_back(memBlock0);
138 memBlocks.push_back(memBlock1);
139 memBlocks.push_back(memBlock2);
140 memBlocks.push_back(memBlock3);
141 memBlocks.push_back(memBlock4);
142
143 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
144 TestMemoryOptimizerStrategy testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
145 auto ptr = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategySingle);
Jim Flynne1fdd282021-10-26 21:26:10 +0100146 StrategyValidator validator;
147 validator.SetStrategy(ptr);
Francis Murtaghca49a242021-09-28 15:30:31 +0100148 // SingleAxisPacking cannot overlap on Y axis
Jim Flynne1fdd282021-10-26 21:26:10 +0100149 CHECK_THROWS(validator.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100150
151 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
152 TestMemoryOptimizerStrategy testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
153 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategy>(testMemoryOptimizerStrategyMulti);
Jim Flynne1fdd282021-10-26 21:26:10 +0100154 StrategyValidator validatorMulti;
155 validatorMulti.SetStrategy(ptrMulti);
Francis Murtaghca49a242021-09-28 15:30:31 +0100156 // MultiAxisPacking can overlap on Y axis
Jim Flynne1fdd282021-10-26 21:26:10 +0100157 CHECK_NOTHROW(validatorMulti.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100158}
159
160// TestMemoryOptimizerStrategyDuplicate: Create a MemBin and put all blocks in it duplicating each so validator
161// can check
162class TestMemoryOptimizerStrategyDuplicate : public TestMemoryOptimizerStrategy
163{
164public:
165 TestMemoryOptimizerStrategyDuplicate(MemBlockStrategyType type)
166 : TestMemoryOptimizerStrategy(type)
167 {}
168
169 std::vector<MemBin> Optimize(std::vector<MemBlock>& memBlocks) override
170 {
171 std::vector<MemBin> memBins;
172 memBins.reserve(memBlocks.size());
173
174 MemBin memBin;
175 memBin.m_MemBlocks.reserve(memBlocks.size());
176 for (auto& memBlock : memBlocks)
177 {
178 memBin.m_MemSize = memBin.m_MemSize + memBlock.m_MemSize;
179 memBin.m_MemBlocks.push_back(memBlock);
180 // Put block in twice so it gets found twice
181 memBin.m_MemBlocks.push_back(memBlock);
182 }
183 memBins.push_back(memBin);
184
185 return memBins;
186 }
187};
188
189TEST_CASE("MemoryOptimizerStrategyValidatorTestDuplicateBlocks")
190{
191 // create a few memory blocks
192 MemBlock memBlock0(0, 2, 20, 0, 0);
193 MemBlock memBlock1(2, 3, 10, 20, 1);
194 MemBlock memBlock2(3, 5, 15, 30, 2);
195 MemBlock memBlock3(5, 6, 20, 50, 3);
196 MemBlock memBlock4(7, 8, 5, 70, 4);
197
198 std::vector<MemBlock> memBlocks;
199 memBlocks.reserve(5);
200 memBlocks.push_back(memBlock0);
201 memBlocks.push_back(memBlock1);
202 memBlocks.push_back(memBlock2);
203 memBlocks.push_back(memBlock3);
204 memBlocks.push_back(memBlock4);
205
206 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
207 // Duplicate strategy is invalid as same block is found twice
208 TestMemoryOptimizerStrategyDuplicate testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
209 auto ptr = std::make_shared<TestMemoryOptimizerStrategyDuplicate>(testMemoryOptimizerStrategySingle);
Jim Flynne1fdd282021-10-26 21:26:10 +0100210 StrategyValidator validator;
211 validator.SetStrategy(ptr);
212 CHECK_THROWS(validator.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100213
214 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
215 TestMemoryOptimizerStrategyDuplicate testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
216 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategyDuplicate>(testMemoryOptimizerStrategyMulti);
Jim Flynne1fdd282021-10-26 21:26:10 +0100217 StrategyValidator validatorMulti;
218 validatorMulti.SetStrategy(ptrMulti);
219 CHECK_THROWS(validatorMulti.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100220}
221
222// TestMemoryOptimizerStrategySkip: Create a MemBin and put all blocks in it skipping every other block so validator
223// can check
224class TestMemoryOptimizerStrategySkip : public TestMemoryOptimizerStrategy
225{
226public:
227 TestMemoryOptimizerStrategySkip(MemBlockStrategyType type)
228 : TestMemoryOptimizerStrategy(type)
229 {}
230
231 std::vector<MemBin> Optimize(std::vector<MemBlock>& memBlocks) override
232 {
233 std::vector<MemBin> memBins;
234 memBins.reserve(memBlocks.size());
235
236 MemBin memBin;
237 memBin.m_MemBlocks.reserve(memBlocks.size());
238 for (unsigned int i = 0; i < memBlocks.size()-1; i+=2)
239 {
240 auto memBlock = memBlocks[i];
241 memBin.m_MemSize = memBin.m_MemSize + memBlock.m_MemSize;
242 memBin.m_MemBlocks.push_back(memBlock);
243 }
244 memBins.push_back(memBin);
245
246 return memBins;
247 }
248};
249
250TEST_CASE("MemoryOptimizerStrategyValidatorTestSkipBlocks")
251{
252 // create a few memory blocks
253 MemBlock memBlock0(0, 2, 20, 0, 0);
254 MemBlock memBlock1(2, 3, 10, 20, 1);
255 MemBlock memBlock2(3, 5, 15, 30, 2);
256 MemBlock memBlock3(5, 6, 20, 50, 3);
257 MemBlock memBlock4(7, 8, 5, 70, 4);
258
259 std::vector<MemBlock> memBlocks;
260 memBlocks.reserve(5);
261 memBlocks.push_back(memBlock0);
262 memBlocks.push_back(memBlock1);
263 memBlocks.push_back(memBlock2);
264 memBlocks.push_back(memBlock3);
265 memBlocks.push_back(memBlock4);
266
267 // Optimize the memory blocks with TestMemoryOptimizerStrategySingle
268 // Skip strategy is invalid as every second block is not found
269 TestMemoryOptimizerStrategySkip testMemoryOptimizerStrategySingle(MemBlockStrategyType::SingleAxisPacking);
270 auto ptr = std::make_shared<TestMemoryOptimizerStrategySkip>(testMemoryOptimizerStrategySingle);
Jim Flynne1fdd282021-10-26 21:26:10 +0100271 StrategyValidator validator;
272 validator.SetStrategy(ptr);
273 CHECK_THROWS(validator.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100274
275 // Optimize the memory blocks with TestMemoryOptimizerStrategyMulti
276 TestMemoryOptimizerStrategySkip testMemoryOptimizerStrategyMulti(MemBlockStrategyType::MultiAxisPacking);
277 auto ptrMulti = std::make_shared<TestMemoryOptimizerStrategySkip>(testMemoryOptimizerStrategyMulti);
Jim Flynne1fdd282021-10-26 21:26:10 +0100278 StrategyValidator validatorMulti;
279 validatorMulti.SetStrategy(ptrMulti);
280 CHECK_THROWS(validatorMulti.Optimize(memBlocks));
Francis Murtaghca49a242021-09-28 15:30:31 +0100281}
282
283}