blob: aa0b3597fa8914e46d9aa8cc1cf91c951d454543 [file] [log] [blame]
FrancisMurtaghf08876f2019-02-04 15:41:17 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <armnn/ArmNN.hpp>
8#include "TestLayerVisitor.hpp"
9#include <boost/test/unit_test.hpp>
10
11namespace armnn
12{
13
14// Concrete TestLayerVisitor subclasses for layers taking Descriptor argument with overridden VisitLayer methods
15class TestPermuteLayerVisitor : public TestLayerVisitor
16{
17private:
18 const PermuteDescriptor m_VisitorDescriptor;
19
20public:
21 explicit TestPermuteLayerVisitor(const PermuteDescriptor& permuteDescriptor, const char* name = nullptr)
22 : TestLayerVisitor(name)
23 , m_VisitorDescriptor(permuteDescriptor.m_DimMappings)
24 {};
25
26 void CheckDescriptor(const PermuteDescriptor& permuteDescriptor)
27 {
28 if (permuteDescriptor.m_DimMappings.GetSize() == m_VisitorDescriptor.m_DimMappings.GetSize())
29 {
30 for (unsigned int i = 0; i < permuteDescriptor.m_DimMappings.GetSize(); ++i)
31 {
32 BOOST_CHECK_EQUAL(permuteDescriptor.m_DimMappings[i], m_VisitorDescriptor.m_DimMappings[i]);
33 }
34 }
35 else
36 {
37 BOOST_ERROR("Unequal vector size for batchToSpaceNdDescriptor m_DimMappings.");
38 }
39 };
40
41 void VisitPermuteLayer(const IConnectableLayer* layer,
42 const PermuteDescriptor& permuteDescriptor,
43 const char* name = nullptr) override
44 {
45 CheckLayerPointer(layer);
46 CheckDescriptor(permuteDescriptor);
47 CheckLayerName(name);
48 };
49};
50
51class TestBatchToSpaceNdLayerVisitor : public TestLayerVisitor
52{
53private:
54 BatchToSpaceNdDescriptor m_VisitorDescriptor;
55
56public:
57 explicit TestBatchToSpaceNdLayerVisitor(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
58 const char* name = nullptr)
59 : TestLayerVisitor(name)
60 , m_VisitorDescriptor(batchToSpaceNdDescriptor.m_BlockShape, batchToSpaceNdDescriptor.m_Crops)
61 {
62 m_VisitorDescriptor.m_DataLayout = batchToSpaceNdDescriptor.m_DataLayout;
63 };
64
65 void CheckDescriptor(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor)
66 {
67 if (batchToSpaceNdDescriptor.m_BlockShape.size() == m_VisitorDescriptor.m_BlockShape.size())
68 {
69 for (unsigned int i = 0; i < batchToSpaceNdDescriptor.m_BlockShape.size(); ++i)
70 {
71 BOOST_CHECK_EQUAL(batchToSpaceNdDescriptor.m_BlockShape[i], m_VisitorDescriptor.m_BlockShape[i]);
72 }
73 }
74 else
75 {
76 BOOST_ERROR("Unequal vector size for batchToSpaceNdDescriptor m_BlockShape.");
77 }
78
79 if (batchToSpaceNdDescriptor.m_Crops.size() == m_VisitorDescriptor.m_Crops.size())
80 {
81 for (unsigned int i = 0; i < batchToSpaceNdDescriptor.m_Crops.size(); ++i)
82 {
83 BOOST_CHECK_EQUAL(batchToSpaceNdDescriptor.m_Crops[i].first, m_VisitorDescriptor.m_Crops[i].first);
84 BOOST_CHECK_EQUAL(batchToSpaceNdDescriptor.m_Crops[i].second, m_VisitorDescriptor.m_Crops[i].second);
85 }
86 }
87 else
88 {
89 BOOST_ERROR("Unequal vector size for batchToSpaceNdDescriptor m_Crops.");
90 }
91
92 BOOST_CHECK(batchToSpaceNdDescriptor.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
93 }
94
95 void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer,
96 const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
97 const char* name = nullptr) override
98 {
99 CheckLayerPointer(layer);
100 CheckDescriptor(batchToSpaceNdDescriptor);
101 CheckLayerName(name);
102 };
103};
104
105class TestPooling2dLayerVisitor : public TestLayerVisitor
106{
107private:
108 Pooling2dDescriptor m_VisitorDescriptor;
109
110public:
111 explicit TestPooling2dLayerVisitor(const Pooling2dDescriptor& pooling2dDescriptor, const char* name = nullptr)
112 : TestLayerVisitor(name)
113 {
114 m_VisitorDescriptor.m_PoolType = pooling2dDescriptor.m_PoolType;
115 m_VisitorDescriptor.m_PadLeft = pooling2dDescriptor.m_PadLeft;
116 m_VisitorDescriptor.m_PadRight = pooling2dDescriptor.m_PadRight;
117 m_VisitorDescriptor.m_PadTop = pooling2dDescriptor.m_PadTop;
118 m_VisitorDescriptor.m_PadBottom = pooling2dDescriptor.m_PadBottom;
119 m_VisitorDescriptor.m_PoolWidth = pooling2dDescriptor.m_PoolWidth;
120 m_VisitorDescriptor.m_PoolHeight = pooling2dDescriptor.m_PoolHeight;
121 m_VisitorDescriptor.m_StrideX = pooling2dDescriptor.m_StrideX;
122 m_VisitorDescriptor.m_StrideY = pooling2dDescriptor.m_StrideY;
123 m_VisitorDescriptor.m_OutputShapeRounding = pooling2dDescriptor.m_OutputShapeRounding;
124 m_VisitorDescriptor.m_PaddingMethod = pooling2dDescriptor.m_PaddingMethod;
125 m_VisitorDescriptor.m_DataLayout = pooling2dDescriptor.m_DataLayout;
126 };
127
128 void CheckDescriptor(const Pooling2dDescriptor& pooling2dDescriptor)
129 {
130 BOOST_CHECK(pooling2dDescriptor.m_PoolType == m_VisitorDescriptor.m_PoolType);
131 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_PadLeft, m_VisitorDescriptor.m_PadLeft);
132 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_PadRight, m_VisitorDescriptor.m_PadRight);
133 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_PadTop, m_VisitorDescriptor.m_PadTop);
134 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_PadBottom, m_VisitorDescriptor.m_PadBottom);
135 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_PoolWidth, m_VisitorDescriptor.m_PoolWidth);
136 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_PoolHeight, m_VisitorDescriptor.m_PoolHeight);
137 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_StrideX, m_VisitorDescriptor.m_StrideX);
138 BOOST_CHECK_EQUAL(pooling2dDescriptor.m_StrideY, m_VisitorDescriptor.m_StrideY);
139 BOOST_CHECK(pooling2dDescriptor.m_OutputShapeRounding == m_VisitorDescriptor.m_OutputShapeRounding);
140 BOOST_CHECK(pooling2dDescriptor.m_PaddingMethod == m_VisitorDescriptor.m_PaddingMethod);
141 BOOST_CHECK(pooling2dDescriptor.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
142 }
143
144 void VisitPooling2dLayer(const IConnectableLayer* layer,
145 const Pooling2dDescriptor& pooling2dDescriptor,
146 const char* name = nullptr) override
147 {
148 CheckLayerPointer(layer);
149 CheckDescriptor(pooling2dDescriptor);
150 CheckLayerName(name);
151 };
152};
153
154class TestActivationLayerVisitor : public TestLayerVisitor
155{
156private:
157 ActivationDescriptor m_VisitorDescriptor;
158
159public:
160 explicit TestActivationLayerVisitor(const ActivationDescriptor& activationDescriptor, const char* name = nullptr)
161 : TestLayerVisitor(name)
162 {
163 m_VisitorDescriptor.m_Function = activationDescriptor.m_Function;
164 m_VisitorDescriptor.m_A = activationDescriptor.m_A;
165 m_VisitorDescriptor.m_B = activationDescriptor.m_B;
166 };
167
168 void CheckDescriptor(const ActivationDescriptor& activationDescriptor)
169 {
170 BOOST_CHECK(activationDescriptor.m_Function == m_VisitorDescriptor.m_Function);
171 BOOST_CHECK_EQUAL(activationDescriptor.m_A, m_VisitorDescriptor.m_A);
172 BOOST_CHECK_EQUAL(activationDescriptor.m_B, m_VisitorDescriptor.m_B);
173 };
174
175 void VisitActivationLayer(const IConnectableLayer* layer,
176 const ActivationDescriptor& activationDescriptor,
177 const char* name = nullptr) override
178 {
179 CheckLayerPointer(layer);
180 CheckDescriptor(activationDescriptor);
181 CheckLayerName(name);
182 };
183};
184
185class TestNormalizationLayerVisitor : public TestLayerVisitor
186{
187private:
188 NormalizationDescriptor m_VisitorDescriptor;
189
190public:
191 explicit TestNormalizationLayerVisitor(const NormalizationDescriptor& normalizationDescriptor,
192 const char* name = nullptr)
193 : TestLayerVisitor(name)
194 {
195 m_VisitorDescriptor.m_NormChannelType = normalizationDescriptor.m_NormChannelType;
196 m_VisitorDescriptor.m_NormMethodType = normalizationDescriptor.m_NormMethodType;
197 m_VisitorDescriptor.m_NormSize = normalizationDescriptor.m_NormSize;
198 m_VisitorDescriptor.m_Alpha = normalizationDescriptor.m_Alpha;
199 m_VisitorDescriptor.m_Beta = normalizationDescriptor.m_Beta;
200 m_VisitorDescriptor.m_K = normalizationDescriptor.m_K;
201 m_VisitorDescriptor.m_DataLayout = normalizationDescriptor.m_DataLayout;
202 };
203
204 void CheckDescriptor(const NormalizationDescriptor& normalizationDescriptor)
205 {
206 BOOST_CHECK(normalizationDescriptor.m_NormChannelType == m_VisitorDescriptor.m_NormChannelType);
207 BOOST_CHECK(normalizationDescriptor.m_NormMethodType == m_VisitorDescriptor.m_NormMethodType);
208 BOOST_CHECK_EQUAL(normalizationDescriptor.m_NormSize, m_VisitorDescriptor.m_NormSize);
209 BOOST_CHECK_EQUAL(normalizationDescriptor.m_Alpha, m_VisitorDescriptor.m_Alpha);
210 BOOST_CHECK_EQUAL(normalizationDescriptor.m_Beta, m_VisitorDescriptor.m_Beta);
211 BOOST_CHECK_EQUAL(normalizationDescriptor.m_K, m_VisitorDescriptor.m_K);
212 BOOST_CHECK(normalizationDescriptor.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
213 }
214
215 void VisitNormalizationLayer(const IConnectableLayer* layer,
216 const NormalizationDescriptor& normalizationDescriptor,
217 const char* name = nullptr) override
218 {
219 CheckLayerPointer(layer);
220 CheckDescriptor(normalizationDescriptor);
221 CheckLayerName(name);
222 };
223};
224
225class TestSoftmaxLayerVisitor : public TestLayerVisitor
226{
227private:
228 SoftmaxDescriptor m_VisitorDescriptor;
229
230public:
231 explicit TestSoftmaxLayerVisitor(const SoftmaxDescriptor& softmaxDescriptor, const char* name = nullptr)
232 : TestLayerVisitor(name)
233 {
234 m_VisitorDescriptor.m_Beta = softmaxDescriptor.m_Beta;
235 };
236
237 void CheckDescriptor(const SoftmaxDescriptor& softmaxDescriptor)
238 {
239 BOOST_CHECK_EQUAL(softmaxDescriptor.m_Beta, m_VisitorDescriptor.m_Beta);
240 }
241
242 void VisitSoftmaxLayer(const IConnectableLayer* layer,
243 const SoftmaxDescriptor& softmaxDescriptor,
244 const char* name = nullptr) override
245 {
246 CheckLayerPointer(layer);
247 CheckDescriptor(softmaxDescriptor);
248 CheckLayerName(name);
249 };
250};
251
252class TestSplitterLayerVisitor : public TestLayerVisitor
253{
254private:
255 ViewsDescriptor m_VisitorDescriptor;
256
257public:
258 explicit TestSplitterLayerVisitor(const ViewsDescriptor& splitterDescriptor, const char* name = nullptr)
259 : TestLayerVisitor(name)
260 , m_VisitorDescriptor(splitterDescriptor.GetNumViews(), splitterDescriptor.GetNumDimensions())
261 {
262 if (splitterDescriptor.GetNumViews() != m_VisitorDescriptor.GetNumViews())
263 {
264 BOOST_ERROR("Unequal number of views in splitter descriptor.");
265 }
266 else if (splitterDescriptor.GetNumDimensions() != m_VisitorDescriptor.GetNumDimensions())
267 {
268 BOOST_ERROR("Unequal number of dimensions in splitter descriptor.");
269 }
270 else
271 {
272 for (unsigned int i = 0; i < splitterDescriptor.GetNumViews(); ++i)
273 {
274 for (unsigned int j = 0; j < splitterDescriptor.GetNumDimensions(); ++j)
275 {
276 m_VisitorDescriptor.SetViewOriginCoord(i, j, splitterDescriptor.GetViewOrigin(i)[j]);
277 m_VisitorDescriptor.SetViewSize(i, j, splitterDescriptor.GetViewSizes(i)[j]);
278 }
279 }
280 }
281 };
282
283 void CheckDescriptor(const ViewsDescriptor& splitterDescriptor)
284 {
285
286 BOOST_CHECK_EQUAL(splitterDescriptor.GetNumViews(), m_VisitorDescriptor.GetNumViews());
287 BOOST_CHECK_EQUAL(splitterDescriptor.GetNumDimensions(), m_VisitorDescriptor.GetNumDimensions());
288
289 if (splitterDescriptor.GetNumViews() != m_VisitorDescriptor.GetNumViews())
290 {
291 BOOST_ERROR("Unequal number of views in splitter descriptor.");
292 }
293 else if (splitterDescriptor.GetNumDimensions() != m_VisitorDescriptor.GetNumDimensions())
294 {
295 BOOST_ERROR("Unequal number of dimensions in splitter descriptor.");
296 }
297 else
298 {
299 for (unsigned int i = 0; i < splitterDescriptor.GetNumViews(); ++i)
300 {
301 for (unsigned int j = 0; j < splitterDescriptor.GetNumDimensions(); ++j)
302 {
303 BOOST_CHECK_EQUAL(splitterDescriptor.GetViewOrigin(i)[j], m_VisitorDescriptor.GetViewOrigin(i)[j]);
304 BOOST_CHECK_EQUAL(splitterDescriptor.GetViewSizes(i)[j], m_VisitorDescriptor.GetViewSizes(i)[j]);
305 }
306 }
307 }
308 };
309
310 void VisitSplitterLayer(const IConnectableLayer* layer,
311 const ViewsDescriptor& splitterDescriptor,
312 const char* name = nullptr) override
313 {
314 CheckLayerPointer(layer);
315 CheckDescriptor(splitterDescriptor);
316 CheckLayerName(name);
317 };
318};
319
Jim Flynne242f2d2019-05-22 14:24:13 +0100320class TestConcatLayerVisitor : public TestLayerVisitor
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000321{
322private:
323 OriginsDescriptor m_VisitorDescriptor;
324
325public:
Jim Flynne242f2d2019-05-22 14:24:13 +0100326 explicit TestConcatLayerVisitor(const OriginsDescriptor& concatDescriptor, const char* name = nullptr)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000327 : TestLayerVisitor(name)
Jim Flynne242f2d2019-05-22 14:24:13 +0100328 , m_VisitorDescriptor(concatDescriptor.GetNumViews(), concatDescriptor.GetNumDimensions())
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000329 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100330 m_VisitorDescriptor.SetConcatAxis(concatDescriptor.GetConcatAxis());
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000331
Jim Flynne242f2d2019-05-22 14:24:13 +0100332 if (concatDescriptor.GetNumViews() != m_VisitorDescriptor.GetNumViews())
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000333 {
334 BOOST_ERROR("Unequal number of views in splitter descriptor.");
335 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100336 else if (concatDescriptor.GetNumDimensions() != m_VisitorDescriptor.GetNumDimensions())
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000337 {
338 BOOST_ERROR("Unequal number of dimensions in splitter descriptor.");
339 }
340 else
341 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100342 for (unsigned int i = 0; i < concatDescriptor.GetNumViews(); ++i)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000343 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100344 for (unsigned int j = 0; j < concatDescriptor.GetNumDimensions(); ++j)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000345 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100346 m_VisitorDescriptor.SetViewOriginCoord(i, j, concatDescriptor.GetViewOrigin(i)[j]);
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000347 }
348 }
349 }
350 };
351
Jim Flynne242f2d2019-05-22 14:24:13 +0100352 void CheckDescriptor(const OriginsDescriptor& concatDescriptor)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000353 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100354 BOOST_CHECK_EQUAL(concatDescriptor.GetNumViews(), m_VisitorDescriptor.GetNumViews());
355 BOOST_CHECK_EQUAL(concatDescriptor.GetNumDimensions(), m_VisitorDescriptor.GetNumDimensions());
356 BOOST_CHECK_EQUAL(concatDescriptor.GetConcatAxis(), m_VisitorDescriptor.GetConcatAxis());
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000357
Jim Flynne242f2d2019-05-22 14:24:13 +0100358 if (concatDescriptor.GetNumViews() != m_VisitorDescriptor.GetNumViews())
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000359 {
360 BOOST_ERROR("Unequal number of views in splitter descriptor.");
361 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100362 else if (concatDescriptor.GetNumDimensions() != m_VisitorDescriptor.GetNumDimensions())
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000363 {
364 BOOST_ERROR("Unequal number of dimensions in splitter descriptor.");
365 }
366 else
367 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100368 for (unsigned int i = 0; i < concatDescriptor.GetNumViews(); ++i)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000369 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100370 for (unsigned int j = 0; j < concatDescriptor.GetNumDimensions(); ++j)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000371 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100372 BOOST_CHECK_EQUAL(concatDescriptor.GetViewOrigin(i)[j], m_VisitorDescriptor.GetViewOrigin(i)[j]);
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000373 }
374 }
375 }
376 }
377
Jim Flynne242f2d2019-05-22 14:24:13 +0100378 void VisitConcatLayer(const IConnectableLayer* layer,
379 const OriginsDescriptor& concatDescriptor,
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000380 const char* name = nullptr) override
381 {
382 CheckLayerPointer(layer);
Jim Flynne242f2d2019-05-22 14:24:13 +0100383 CheckDescriptor(concatDescriptor);
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000384 CheckLayerName(name);
385 };
386};
387
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100388class TestResizeLayerVisitor : public TestLayerVisitor
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000389{
390private:
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100391 ResizeDescriptor m_VisitorDescriptor;
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000392
393public:
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100394 explicit TestResizeLayerVisitor(const ResizeDescriptor& descriptor, const char* name = nullptr)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000395 : TestLayerVisitor(name)
396 {
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100397 m_VisitorDescriptor.m_Method = descriptor.m_Method;
398 m_VisitorDescriptor.m_TargetWidth = descriptor.m_TargetWidth;
399 m_VisitorDescriptor.m_TargetHeight = descriptor.m_TargetHeight;
400 m_VisitorDescriptor.m_DataLayout = descriptor.m_DataLayout;
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000401 };
402
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100403 void CheckDescriptor(const ResizeDescriptor& descriptor)
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000404 {
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100405 BOOST_CHECK(descriptor.m_Method == m_VisitorDescriptor.m_Method);
406 BOOST_CHECK(descriptor.m_TargetWidth == m_VisitorDescriptor.m_TargetWidth);
407 BOOST_CHECK(descriptor.m_TargetHeight == m_VisitorDescriptor.m_TargetHeight);
408 BOOST_CHECK(descriptor.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000409 }
410
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100411 void VisitResizeLayer(const IConnectableLayer* layer,
412 const ResizeDescriptor& descriptor,
413 const char* name = nullptr) override
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000414 {
415 CheckLayerPointer(layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100416 CheckDescriptor(descriptor);
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000417 CheckLayerName(name);
418 };
419};
420
Kevin Mayce5045a2019-10-02 14:07:47 +0100421class TestInstanceNormalizationLayerVisitor : public TestLayerVisitor
422{
423private:
424 InstanceNormalizationDescriptor m_VisitorDescriptor;
425
426public:
427 explicit TestInstanceNormalizationLayerVisitor(const InstanceNormalizationDescriptor& desc,
428 const char* name = nullptr)
429 : TestLayerVisitor(name)
430 {
431 m_VisitorDescriptor.m_Beta = desc.m_Beta;
432 m_VisitorDescriptor.m_Gamma = desc.m_Gamma;
433 m_VisitorDescriptor.m_Eps = desc.m_Eps;
434 m_VisitorDescriptor.m_DataLayout = desc.m_DataLayout;
435 };
436
437 void CheckDescriptor(const InstanceNormalizationDescriptor& desc)
438 {
439 BOOST_CHECK(desc.m_Beta == m_VisitorDescriptor.m_Beta);
440 BOOST_CHECK(desc.m_Gamma == m_VisitorDescriptor.m_Gamma);
441 BOOST_CHECK(desc.m_Eps == m_VisitorDescriptor.m_Eps);
442 BOOST_CHECK(desc.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
443 }
444
445 void VisitInstanceNormalizationLayer(const IConnectableLayer* layer,
446 const InstanceNormalizationDescriptor& desc,
447 const char* name = nullptr) override
448 {
449 CheckLayerPointer(layer);
450 CheckDescriptor(desc);
451 CheckLayerName(name);
452 };
453};
454
FrancisMurtaghf08876f2019-02-04 15:41:17 +0000455class TestL2NormalizationLayerVisitor : public TestLayerVisitor
456{
457private:
458 L2NormalizationDescriptor m_VisitorDescriptor;
459
460public:
461 explicit TestL2NormalizationLayerVisitor(const L2NormalizationDescriptor& desc, const char* name = nullptr)
462 : TestLayerVisitor(name)
463 {
464 m_VisitorDescriptor.m_DataLayout = desc.m_DataLayout;
465 };
466
467 void CheckDescriptor(const L2NormalizationDescriptor& desc)
468 {
469 BOOST_CHECK(desc.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
470 }
471
472 void VisitL2NormalizationLayer(const IConnectableLayer* layer,
473 const L2NormalizationDescriptor& desc,
474 const char* name = nullptr) override
475 {
476 CheckLayerPointer(layer);
477 CheckDescriptor(desc);
478 CheckLayerName(name);
479 };
480};
481
482class TestReshapeLayerVisitor : public TestLayerVisitor
483{
484private:
485 const ReshapeDescriptor m_VisitorDescriptor;
486
487public:
488 explicit TestReshapeLayerVisitor(const ReshapeDescriptor& reshapeDescriptor, const char* name = nullptr)
489 : TestLayerVisitor(name)
490 , m_VisitorDescriptor(reshapeDescriptor.m_TargetShape)
491 {};
492
493 void CheckDescriptor(const ReshapeDescriptor& reshapeDescriptor)
494 {
495 BOOST_CHECK_MESSAGE(reshapeDescriptor.m_TargetShape == m_VisitorDescriptor.m_TargetShape,
496 reshapeDescriptor.m_TargetShape << " compared to " << m_VisitorDescriptor.m_TargetShape);
497 }
498
499 void VisitReshapeLayer(const IConnectableLayer* layer,
500 const ReshapeDescriptor& reshapeDescriptor,
501 const char* name = nullptr) override
502 {
503 CheckLayerPointer(layer);
504 CheckDescriptor(reshapeDescriptor);
505 CheckLayerName(name);
506 };
507};
508
509class TestSpaceToBatchNdLayerVisitor : public TestLayerVisitor
510{
511private:
512 SpaceToBatchNdDescriptor m_VisitorDescriptor;
513
514public:
515 explicit TestSpaceToBatchNdLayerVisitor(const SpaceToBatchNdDescriptor& desc, const char* name = nullptr)
516 : TestLayerVisitor(name)
517 , m_VisitorDescriptor(desc.m_BlockShape, desc.m_PadList)
518 {
519 m_VisitorDescriptor.m_DataLayout = desc.m_DataLayout;
520 };
521
522 void CheckDescriptor(const SpaceToBatchNdDescriptor& desc)
523 {
524 if (desc.m_BlockShape.size() == m_VisitorDescriptor.m_BlockShape.size())
525 {
526 for (unsigned int i = 0; i < desc.m_BlockShape.size(); ++i)
527 {
528 BOOST_CHECK_EQUAL(desc.m_BlockShape[i], m_VisitorDescriptor.m_BlockShape[i]);
529 }
530 }
531 else
532 {
533 BOOST_ERROR("Unequal vector size for SpaceToBatchNdDescriptor m_BlockShape.");
534 }
535
536 if (desc.m_PadList.size() == m_VisitorDescriptor.m_PadList.size())
537 {
538 for (unsigned int i = 0; i < desc.m_PadList.size(); ++i)
539 {
540 BOOST_CHECK_EQUAL(desc.m_PadList[i].first, m_VisitorDescriptor.m_PadList[i].first);
541 BOOST_CHECK_EQUAL(desc.m_PadList[i].second, m_VisitorDescriptor.m_PadList[i].second);
542 }
543 }
544 else
545 {
546 BOOST_ERROR("Unequal vector size for SpaceToBatchNdDescriptor m_PadList.");
547 }
548
549 BOOST_CHECK(desc.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
550 }
551
552 void VisitSpaceToBatchNdLayer(const IConnectableLayer* layer,
553 const SpaceToBatchNdDescriptor& desc,
554 const char* name = nullptr) override
555 {
556 CheckLayerPointer(layer);
557 CheckDescriptor(desc);
558 CheckLayerName(name);
559 };
560};
561
562class TestMeanLayerVisitor : public TestLayerVisitor
563{
564private:
565 const MeanDescriptor m_VisitorDescriptor;
566
567public:
568 explicit TestMeanLayerVisitor(const MeanDescriptor& meanDescriptor, const char* name = nullptr)
569 : TestLayerVisitor(name)
570 , m_VisitorDescriptor(meanDescriptor.m_Axis, meanDescriptor.m_KeepDims)
571 {};
572
573 void CheckDescriptor(const MeanDescriptor& meanDescriptor)
574 {
575 if (meanDescriptor.m_Axis.size() == m_VisitorDescriptor.m_Axis.size())
576 {
577 for (unsigned int i = 0; i < meanDescriptor.m_Axis.size(); ++i)
578 {
579 BOOST_CHECK_EQUAL(meanDescriptor.m_Axis[i], m_VisitorDescriptor.m_Axis[i]);
580 }
581 }
582 else
583 {
584 BOOST_ERROR("Unequal vector size for MeanDescriptor m_Axis.");
585 }
586
587 BOOST_CHECK_EQUAL(meanDescriptor.m_KeepDims, m_VisitorDescriptor.m_KeepDims);
588 }
589
590 void VisitMeanLayer(const IConnectableLayer* layer,
591 const MeanDescriptor& meanDescriptor,
592 const char* name = nullptr) override
593 {
594 CheckLayerPointer(layer);
595 CheckDescriptor(meanDescriptor);
596 CheckLayerName(name);
597 };
598};
599
600class TestPadLayerVisitor : public TestLayerVisitor
601{
602private:
603 const PadDescriptor m_VisitorDescriptor;
604
605public:
606 explicit TestPadLayerVisitor(const PadDescriptor& padDescriptor, const char* name = nullptr)
607 : TestLayerVisitor(name)
608 , m_VisitorDescriptor(padDescriptor.m_PadList)
609 {};
610
611 void CheckDescriptor(const PadDescriptor& padDescriptor)
612 {
613 if (padDescriptor.m_PadList.size() == m_VisitorDescriptor.m_PadList.size())
614 {
615 for (unsigned int i = 0; i < padDescriptor.m_PadList.size(); ++i)
616 {
617 BOOST_CHECK_EQUAL(padDescriptor.m_PadList[i].first, m_VisitorDescriptor.m_PadList[i].first);
618 BOOST_CHECK_EQUAL(padDescriptor.m_PadList[i].second, m_VisitorDescriptor.m_PadList[i].second);
619 }
620 }
621 else
622 {
623 BOOST_ERROR("Unequal vector size for SpaceToBatchNdDescriptor m_PadList.");
624 }
625 }
626
627 void VisitPadLayer(const IConnectableLayer* layer,
628 const PadDescriptor& padDescriptor,
629 const char* name = nullptr) override
630 {
631 CheckLayerPointer(layer);
632 CheckDescriptor(padDescriptor);
633 CheckLayerName(name);
634 };
635};
636
637class TestStridedSliceLayerVisitor : public TestLayerVisitor
638{
639private:
640 StridedSliceDescriptor m_VisitorDescriptor;
641
642public:
643 explicit TestStridedSliceLayerVisitor(const StridedSliceDescriptor& desc, const char* name = nullptr)
644 : TestLayerVisitor(name)
645 , m_VisitorDescriptor(desc.m_Begin, desc.m_End, desc.m_Stride)
646 {
647 m_VisitorDescriptor.m_BeginMask = desc.m_BeginMask;
648 m_VisitorDescriptor.m_EndMask = desc.m_EndMask;
649 m_VisitorDescriptor.m_ShrinkAxisMask = desc.m_ShrinkAxisMask;
650 m_VisitorDescriptor.m_EllipsisMask = desc.m_EllipsisMask;
651 m_VisitorDescriptor.m_NewAxisMask = desc.m_NewAxisMask;
652 m_VisitorDescriptor.m_DataLayout = desc.m_DataLayout;
653 };
654
655 void CheckDescriptor(const StridedSliceDescriptor& desc)
656 {
657 if (desc.m_Begin.size() == m_VisitorDescriptor.m_Begin.size())
658 {
659 for (unsigned int i = 0; i < desc.m_Begin.size(); ++i)
660 {
661 BOOST_CHECK_EQUAL(desc.m_Begin[i], m_VisitorDescriptor.m_Begin[i]);
662 }
663 }
664 else
665 {
666 BOOST_ERROR("Unequal vector size for StridedSliceDescriptor m_Begin.");
667 }
668
669 if (desc.m_End.size() == m_VisitorDescriptor.m_End.size())
670 {
671 for (unsigned int i = 0; i < desc.m_End.size(); ++i)
672 {
673 BOOST_CHECK_EQUAL(desc.m_End[i], m_VisitorDescriptor.m_End[i]);
674 }
675 }
676 else
677 {
678 BOOST_ERROR("Unequal vector size for StridedSliceDescriptor m_End.");
679 }
680
681 if (desc.m_Stride.size() == m_VisitorDescriptor.m_Stride.size())
682 {
683 for (unsigned int i = 0; i < desc.m_Stride.size(); ++i)
684 {
685 BOOST_CHECK_EQUAL(desc.m_Stride[i], m_VisitorDescriptor.m_Stride[i]);
686 }
687 }
688 else
689 {
690 BOOST_ERROR("Unequal vector size for StridedSliceDescriptor m_Stride.");
691 }
692
693 BOOST_CHECK_EQUAL(desc.m_BeginMask, m_VisitorDescriptor.m_BeginMask);
694 BOOST_CHECK_EQUAL(desc.m_EndMask, m_VisitorDescriptor.m_EndMask);
695 BOOST_CHECK_EQUAL(desc.m_ShrinkAxisMask, m_VisitorDescriptor.m_ShrinkAxisMask);
696 BOOST_CHECK_EQUAL(desc.m_EllipsisMask, m_VisitorDescriptor.m_EllipsisMask);
697 BOOST_CHECK_EQUAL(desc.m_NewAxisMask, m_VisitorDescriptor.m_NewAxisMask);
698 BOOST_CHECK(desc.m_DataLayout == m_VisitorDescriptor.m_DataLayout);
699 }
700
701 void VisitStridedSliceLayer(const IConnectableLayer* layer,
702 const StridedSliceDescriptor& desc,
703 const char* name = nullptr) override
704 {
705 CheckLayerPointer(layer);
706 CheckDescriptor(desc);
707 CheckLayerName(name);
708 };
709};
710
711} //namespace armnn