blob: 8969344d6dcbe4decd0e2240f40dbaecd17feb24 [file] [log] [blame]
Teresa Charlin18147332021-11-17 14:34:30 +00001# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
Richard Burtondc0c6ed2020-04-08 16:39:05 +01002# SPDX-License-Identifier: MIT
3import inspect
4
5import pytest
6
7import pyarmnn as ann
8import numpy as np
9import pyarmnn._generated.pyarmnn as generated
10
11
def test_activation_descriptor_default_values():
    """Check the attribute values of a default-constructed ActivationDescriptor."""
    desc = ann.ActivationDescriptor()
    expected = {
        "m_Function": ann.ActivationFunction_Sigmoid,
        "m_A": 0,
        "m_B": 0,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
17
18
def test_argminmax_descriptor_default_values():
    """A default ArgMinMaxDescriptor uses ArgMinMaxFunction_Min and axis -1."""
    defaults = ann.ArgMinMaxDescriptor()
    assert defaults.m_Function == ann.ArgMinMaxFunction_Min
    assert defaults.m_Axis == -1
23
24
def test_batchnormalization_descriptor_default_values():
    """A default BatchNormalizationDescriptor is NCHW with an epsilon of 0.0001."""
    desc = ann.BatchNormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # np.allclose only returns a bool; without `assert` this check could never fail.
    assert np.allclose(0.0001, desc.m_Eps)
29
30
def test_batchtospacend_descriptor_default_values():
    """A default BatchToSpaceNdDescriptor is NCHW with block shape [1, 1] and zero crops."""
    b2s = ann.BatchToSpaceNdDescriptor()
    assert ann.DataLayout_NCHW == b2s.m_DataLayout
    assert [1, 1] == b2s.m_BlockShape
    assert [(0, 0), (0, 0)] == b2s.m_Crops
36
37
def test_batchtospacend_descriptor_assignment():
    """Assigned block shape, crops and layout are read back from the descriptor."""
    desc = ann.BatchToSpaceNdDescriptor()
    desc.m_BlockShape = (1, 2, 3)

    crops = [(1, 2), (3, 4)]
    crops_len_before = len(crops)
    desc.m_Crops = crops
    # The assignment must leave the caller's list length unchanged.
    assert crops_len_before == len(crops)

    desc.m_DataLayout = ann.DataLayout_NHWC
    assert ann.DataLayout_NHWC == desc.m_DataLayout
    # The tuple assigned to m_BlockShape is read back as a list.
    assert [1, 2, 3] == desc.m_BlockShape
    assert [(1, 2), (3, 4)] == desc.m_Crops
51
52
@pytest.mark.parametrize(
    "input_shape, value, vtype",
    [
        ([-1], -1, 'int'),
        (("one", "two"), "'one'", 'str'),
        ([1.33, 4.55], 1.33, 'float'),
        ([{1: "one"}], "{1: 'one'}", 'dict'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_shape(input_shape, value, vtype):
    """Block shapes with elements the binding cannot convert raise TypeError.

    ``value`` and ``vtype`` are the element repr and Python type name expected in the
    binding's error message.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as err:
        desc.m_BlockShape = input_shape

    expected = f"Failed to convert python input value {value} of type '{vtype}' to C type 'j'"
    assert expected in str(err.value)
62
63
@pytest.mark.parametrize(
    "input_crops, value, vtype",
    [
        ([(1, 2), (3, 4, 5)], '(3, 4, 5)', 'tuple'),
        ([(1, 'one')], "(1, 'one')", 'tuple'),
        ([-1], -1, 'int'),
        ([(1, (1, 2))], '(1, (1, 2))', 'tuple'),
        ([[1, [1, 2]]], '[1, [1, 2]]', 'list'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_crops(input_crops, value, vtype):
    """Crop lists whose entries are not valid integer pairs raise TypeError.

    ``value`` and ``vtype`` are the entry repr and Python type name expected in the
    binding's error message.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as err:
        desc.m_Crops = input_crops

    expected = f"Failed to convert python input value {value} of type '{vtype}' to C type"
    assert expected in str(err.value)
76
77
def test_batchtospacend_descriptor_empty_assignment():
    """An empty block shape round-trips as an empty list."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    descriptor.m_BlockShape = []
    assert [] == descriptor.m_BlockShape
82
83
def test_batchtospacend_descriptor_ctor():
    """Constructor arguments set block shape and crops; the layout stays NCHW."""
    b2s = ann.BatchToSpaceNdDescriptor([1, 2, 3], [(4, 5), (6, 7)])
    assert ann.DataLayout_NCHW == b2s.m_DataLayout
    assert [1, 2, 3] == b2s.m_BlockShape
    assert [(4, 5), (6, 7)] == b2s.m_Crops
89
90
def test_channelshuffle_descriptor_default_values():
    """A default ChannelShuffleDescriptor has axis 0 and zero groups."""
    shuffle = ann.ChannelShuffleDescriptor()
    for attr in ("m_Axis", "m_NumGroups"):
        assert getattr(shuffle, attr) == 0
95
def test_convolution2d_descriptor_default_values():
    """Check the attribute values of a default-constructed Convolution2dDescriptor."""
    desc = ann.Convolution2dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
108
def test_convolution3d_descriptor_default_values():
    """Check the attribute values of a default-constructed Convolution3dDescriptor."""
    desc = ann.Convolution3dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PadFront": 0,
        "m_PadBack": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_StrideZ": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_DilationZ": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NDHWC,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
125
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100126
def test_depthtospace_descriptor_default_values():
    """A default DepthToSpaceDescriptor has block size 1 and NHWC layout."""
    d2s = ann.DepthToSpaceDescriptor()
    assert d2s.m_BlockSize == 1
    assert d2s.m_DataLayout == ann.DataLayout_NHWC
131
132
def test_depthwise_convolution2d_descriptor_default_values():
    """Check the attribute values of a default DepthwiseConvolution2dDescriptor."""
    desc = ann.DepthwiseConvolution2dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
145
146
def test_detectionpostprocess_descriptor_default_values():
    """Check the attribute values of a default DetectionPostProcessDescriptor."""
    desc = ann.DetectionPostProcessDescriptor()
    expected = {
        "m_MaxDetections": 0,
        "m_MaxClassesPerDetection": 1,
        "m_DetectionsPerClass": 1,
        "m_NmsScoreThreshold": 0,
        "m_NmsIouThreshold": 0,
        "m_NumClasses": 0,
        "m_UseRegularNms": False,
        "m_ScaleH": 0,
        "m_ScaleW": 0,
        "m_ScaleX": 0,
        "m_ScaleY": 0,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
160
161
def test_fakequantization_descriptor_default_values():
    """A default FakeQuantizationDescriptor has the range [-6, 6]."""
    desc = ann.FakeQuantizationDescriptor()
    # np.allclose only returns a bool; the results must be asserted to be checked at all.
    assert np.allclose(6, desc.m_Max)
    assert np.allclose(-6, desc.m_Min)
166
167
def test_fill_descriptor_default_values():
    """A default FillDescriptor fills with 0."""
    desc = ann.FillDescriptor()
    # np.allclose only returns a bool; without `assert` this check could never fail.
    assert np.allclose(0, desc.m_Value)
171
172
def test_gather_descriptor_default_values():
    """A default GatherDescriptor gathers along axis 0."""
    gather = ann.GatherDescriptor()
    assert gather.m_Axis == 0
176
177
def test_fully_connected_descriptor_default_values():
    """A default FullyConnectedDescriptor has bias and weight transposition disabled."""
    fc = ann.FullyConnectedDescriptor()
    for flag in ("m_BiasEnabled", "m_TransposeWeightMatrix"):
        assert getattr(fc, flag) == False
182
183
def test_instancenormalization_descriptor_default_values():
    """A default InstanceNormalizationDescriptor: gamma 1, beta 0, NCHW, eps 1e-12."""
    desc = ann.InstanceNormalizationDescriptor()
    assert desc.m_Gamma == 1
    assert desc.m_Beta == 0
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # The original `np.allclose(...)` was never asserted. Its default atol (1e-8) would also
    # accept any |m_Eps| below ~1e-8 (even 0), so compare relatively only. rtol leaves room
    # for single-precision storage of the value (assumed on the C++ side; confirm).
    assert np.isclose(desc.m_Eps, 1e-12, rtol=1e-5, atol=0.0)
190
191
def test_lstm_descriptor_default_values():
    """Check the attribute values of a default-constructed LstmDescriptor."""
    desc = ann.LstmDescriptor()
    expected = {
        "m_ActivationFunc": 1,
        "m_ClippingThresCell": 0,
        "m_ClippingThresProj": 0,
        "m_CifgEnabled": True,
        "m_PeepholeEnabled": False,
        "m_ProjectionEnabled": False,
        "m_LayerNormEnabled": False,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
201
202
def test_l2normalization_descriptor_default_values():
    """A default L2NormalizationDescriptor is NCHW with an epsilon of 1e-12."""
    desc = ann.L2NormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # The original `np.allclose(...)` was never asserted. Its default atol (1e-8) would also
    # accept any |m_Eps| below ~1e-8 (even 0), so compare relatively only. rtol leaves room
    # for single-precision storage of the value (assumed on the C++ side; confirm).
    assert np.isclose(desc.m_Eps, 1e-12, rtol=1e-5, atol=0.0)
207
208
def test_mean_descriptor_default_values():
    """A default MeanDescriptor does not keep reduced dimensions."""
    mean = ann.MeanDescriptor()
    assert mean.m_KeepDims == False
212
213
def test_normalization_descriptor_default_values():
    """Check the attribute values of a default-constructed NormalizationDescriptor."""
    desc = ann.NormalizationDescriptor()
    expected = {
        "m_NormChannelType": ann.NormalizationAlgorithmChannel_Across,
        "m_NormMethodType": ann.NormalizationAlgorithmMethod_LocalBrightness,
        "m_NormSize": 0,
        "m_Alpha": 0,
        "m_Beta": 0,
        "m_K": 0,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
223
224
def test_origin_descriptor_default_values():
    """A default ConcatDescriptor has no views, no dimensions and concat axis 1."""
    concat = ann.ConcatDescriptor()
    assert concat.GetNumViews() == 0
    assert concat.GetNumDimensions() == 0
    assert concat.GetConcatAxis() == 1
230
231
def test_origin_descriptor_incorrect_views():
    """SetViewOriginCoord with out-of-range indices raises RuntimeError."""
    concat = ann.ConcatDescriptor(2, 2)
    with pytest.raises(RuntimeError) as excinfo:
        concat.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(excinfo.value)
237
238
def test_origin_descriptor_ctor():
    """Per-view origin coordinates and the concat axis round-trip through the getters."""
    concat = ann.ConcatDescriptor(2, 2)
    base = 5
    for view in range(concat.GetNumViews()):
        for dim in range(concat.GetNumDimensions()):
            # Each view's origin is `base + view` on every dimension.
            concat.SetViewOriginCoord(view, dim, base + view)
    concat.SetConcatAxis(1)

    assert 2 == concat.GetNumViews()
    assert 2 == concat.GetNumDimensions()
    assert [5, 5] == concat.GetViewOrigin(0)
    assert [6, 6] == concat.GetViewOrigin(1)
    assert 1 == concat.GetConcatAxis()
252
253
def test_pad_descriptor_default_values():
    """A default PadDescriptor has pad value 0 and PaddingMode_Constant."""
    desc = ann.PadDescriptor()
    expected = {
        "m_PadValue": 0,
        "m_PaddingMode": ann.PaddingMode_Constant,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100258
259
def test_permute_descriptor_default_values():
    """PermuteDescriptor exposes the PermutationVector it was built from."""
    mapping = (0, 2, 3, 1)
    # Keep the vector in a local so it stays alive for the duration of the test.
    perm_vec = ann.PermutationVector(mapping)
    desc = ann.PermuteDescriptor(perm_vec)
    assert desc.m_DimMappings.GetSize() == len(mapping)
    for index, expected in enumerate(mapping):
        assert desc.m_DimMappings[index] == expected
268
269
def test_pooling_descriptor_default_values():
    """Check the attribute values of a default-constructed Pooling2dDescriptor."""
    desc = ann.Pooling2dDescriptor()
    expected = {
        "m_PoolType": ann.PoolingAlgorithm_Max,
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PoolHeight": 0,
        "m_PoolWidth": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_OutputShapeRounding": ann.OutputShapeRounding_Floor,
        "m_PaddingMethod": ann.PaddingMethod_Exclude,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
284
def test_pooling_3d_descriptor_default_values():
    """Check the attribute values of a default-constructed Pooling3dDescriptor."""
    desc = ann.Pooling3dDescriptor()
    expected = {
        "m_PoolType": ann.PoolingAlgorithm_Max,
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PadFront": 0,
        "m_PadBack": 0,
        "m_PoolHeight": 0,
        "m_PoolWidth": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_StrideZ": 0,
        "m_OutputShapeRounding": ann.OutputShapeRounding_Floor,
        "m_PaddingMethod": ann.PaddingMethod_Exclude,
        "m_DataLayout": ann.DataLayout_NCDHW,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
302
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100303
def test_reshape_descriptor_default_values():
    """A default ReshapeDescriptor has an empty (zero-dimensional) target shape."""
    reshape = ann.ReshapeDescriptor()
    assert 0 == reshape.m_TargetShape.GetNumDimensions()
308
def test_reduce_descriptor_default_values():
    """A default ReduceDescriptor sums over no explicit axes without keeping dims."""
    desc = ann.ReduceDescriptor()
    expected = {
        "m_KeepDims": False,
        "m_vAxis": [],
        "m_ReduceOperation": ann.ReduceOperation_Sum,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100314
def test_slice_descriptor_empty_default_values():
    """A default-constructed SliceDescriptor has empty begin and size lists.

    The name must differ from ``test_slice_descriptor_default_values`` (the assignment
    test later in this module). A second ``def`` with the same name would rebind it,
    and this test would never be collected.
    """
    desc = ann.SliceDescriptor()
    assert [] == desc.m_Begin
    assert [] == desc.m_Size
321
322
def test_resize_descriptor_default_values():
    """Check the attribute values of a default-constructed ResizeDescriptor."""
    desc = ann.ResizeDescriptor()
    expected = {
        "m_TargetWidth": 0,
        "m_TargetHeight": 0,
        "m_Method": ann.ResizeMethod_NearestNeighbor,
        "m_DataLayout": ann.DataLayout_NCHW,
        "m_AlignCorners": False,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100330
331
def test_spacetobatchnd_descriptor_default_values():
    """A default SpaceToBatchNdDescriptor uses the NCHW layout."""
    s2b = ann.SpaceToBatchNdDescriptor()
    assert s2b.m_DataLayout == ann.DataLayout_NCHW
335
336
def test_spacetodepth_descriptor_default_values():
    """A default SpaceToDepthDescriptor has block size 1 and NHWC layout."""
    s2d = ann.SpaceToDepthDescriptor()
    assert s2d.m_BlockSize == 1
    assert s2d.m_DataLayout == ann.DataLayout_NHWC
341
342
def test_stack_descriptor_default_values():
    """A default StackDescriptor: axis 0, no inputs, empty input shape."""
    stack = ann.StackDescriptor()
    assert stack.m_Axis == 0
    assert stack.m_NumInputs == 0
    # The input shape starts out with zero dimensions.
    assert stack.m_InputShape.GetNumDimensions() == 0
349
350
def test_slice_descriptor_default_values():
    """Begin and size values assigned to a SliceDescriptor are read back as lists."""
    slice_desc = ann.SliceDescriptor()
    slice_desc.m_Begin = [1, 2, 3, 4, 5]
    slice_desc.m_Size = (1, 2, 3, 4)

    assert [1, 2, 3, 4, 5] == slice_desc.m_Begin
    # The tuple assigned above comes back as a list.
    assert [1, 2, 3, 4] == slice_desc.m_Size
358
359
def test_slice_descriptor_ctor():
    """Constructor arguments (a list and a tuple) are read back as lists."""
    slice_desc = ann.SliceDescriptor([1, 2, 3, 4, 5], (1, 2, 3, 4))
    assert [1, 2, 3, 4, 5] == slice_desc.m_Begin
    assert [1, 2, 3, 4] == slice_desc.m_Size
365
366
def test_strided_slice_descriptor_default_values():
    """Values assigned to StridedSliceDescriptor fields are read back unchanged."""
    desc = ann.StridedSliceDescriptor()
    desc.m_Begin = [1, 2, 3, 4, 5]
    desc.m_End = [6, 7, 8, 9, 10]
    desc.m_Stride = (10, 10)
    masks = {
        "m_BeginMask": 1,
        "m_EndMask": 2,
        "m_ShrinkAxisMask": 3,
        "m_EllipsisMask": 4,
        "m_NewAxisMask": 5,
    }
    for attr, value in masks.items():
        setattr(desc, attr, value)

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    assert [10, 10] == desc.m_Stride
    for attr, value in masks.items():
        assert value == getattr(desc, attr)
386
387
def test_strided_slice_descriptor_ctor():
    """Constructor arguments populate begin, end and stride.

    The fields are deliberately not re-assigned before asserting. Re-assigning them
    would make the test pass regardless of what the constructor stored.
    """
    desc = ann.StridedSliceDescriptor([1, 2, 3, 4, 5], [6, 7, 8, 9, 10], (10, 10))

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    # The tuple passed to the constructor comes back as a list.
    assert [10, 10] == desc.m_Stride
397
398
def test_softmax_descriptor_default_values():
    """A default SoftmaxDescriptor has axis -1 and beta 1.0."""
    desc = ann.SoftmaxDescriptor()
    assert desc.m_Axis == -1
    # np.allclose only returns a bool; without `assert` this check could never fail.
    assert np.allclose(1.0, desc.m_Beta)
403
404
def test_space_to_batch_nd_descriptor_default_values():
    """A default SpaceToBatchNdDescriptor: block shape [1, 1], zero padding, NCHW."""
    s2b = ann.SpaceToBatchNdDescriptor()
    assert [1, 1] == s2b.m_BlockShape
    assert [(0, 0), (0, 0)] == s2b.m_PadList
    assert s2b.m_DataLayout == ann.DataLayout_NCHW
410
411
def test_space_to_batch_nd_descriptor_assigned_values():
    """Assigned block shape and pad list round-trip; the layout keeps its NCHW default."""
    s2b = ann.SpaceToBatchNdDescriptor()
    s2b.m_BlockShape = (90, 100)
    s2b.m_PadList = [(1, 2), (3, 4)]

    assert [90, 100] == s2b.m_BlockShape
    assert [(1, 2), (3, 4)] == s2b.m_PadList
    assert s2b.m_DataLayout == ann.DataLayout_NCHW
419
420
def test_space_to_batch_nd_descriptor_ctor():
    """Constructor arguments set block shape and pad list; the layout stays NCHW."""
    s2b = ann.SpaceToBatchNdDescriptor((1, 2, 3), [(1, 2), (3, 4)])
    assert [1, 2, 3] == s2b.m_BlockShape
    assert [(1, 2), (3, 4)] == s2b.m_PadList
    assert s2b.m_DataLayout == ann.DataLayout_NCHW
426
427
def test_transpose_convolution2d_descriptor_default_values():
    """Check the attribute values of a default TransposeConvolution2dDescriptor."""
    desc = ann.TransposeConvolution2dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
        "m_OutputShapeEnabled": False,
    }
    for attr, value in expected.items():
        assert getattr(desc, attr) == value
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100439
def test_transpose_descriptor_default_values():
    """TransposeDescriptor exposes the PermutationVector it was built from."""
    mapping = (0, 3, 2, 1, 4)
    # Keep the vector in a local so it stays alive for the duration of the test.
    perm_vec = ann.PermutationVector(mapping)
    desc = ann.TransposeDescriptor(perm_vec)
    assert desc.m_DimMappings.GetSize() == len(mapping)
    for index, expected in enumerate(mapping):
        assert desc.m_DimMappings[index] == expected
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100449
def test_view_descriptor_default_values():
    """A default SplitterDescriptor has no views and no dimensions."""
    splitter = ann.SplitterDescriptor()
    assert splitter.GetNumViews() == 0
    assert splitter.GetNumDimensions() == 0
454
455
def test_elementwise_unary_descriptor_default_values():
    """A default ElementwiseUnaryDescriptor uses UnaryOperation_Abs."""
    unary = ann.ElementwiseUnaryDescriptor()
    assert unary.m_Operation == ann.UnaryOperation_Abs
459
460
def test_logical_binary_descriptor_default_values():
    """A default LogicalBinaryDescriptor uses LogicalBinaryOperation_LogicalAnd."""
    logical = ann.LogicalBinaryDescriptor()
    assert logical.m_Operation == ann.LogicalBinaryOperation_LogicalAnd
464
def test_view_descriptor_incorrect_input():
    """Out-of-range view/dimension indices make both setters raise RuntimeError."""
    splitter = ann.SplitterDescriptor(2, 3)

    with pytest.raises(RuntimeError) as origin_err:
        splitter.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(origin_err.value)

    with pytest.raises(RuntimeError) as size_err:
        splitter.SetViewSize(1000, 100, 1000)
    assert "Failed to set view size." in str(size_err.value)
474
475
def test_view_descriptor_ctor():
    """Per-view origins and sizes round-trip through the SplitterDescriptor getters."""
    splitter = ann.SplitterDescriptor(2, 3)
    size_base = 1
    origin_base = 5
    for view in range(splitter.GetNumViews()):
        for dim in range(splitter.GetNumDimensions()):
            splitter.SetViewOriginCoord(view, dim, origin_base + view)
            splitter.SetViewSize(view, dim, size_base + view)

    assert 2 == splitter.GetNumViews()
    assert 3 == splitter.GetNumDimensions()
    # NOTE(review): the descriptor has 3 dimensions but the expected lists below have
    # 2 entries (the view count). This presumably mirrors how the binding sizes the
    # returned list; confirm before changing.
    assert [5, 5] == splitter.GetViewOrigin(0)
    assert [6, 6] == splitter.GetViewOrigin(1)
    assert [1, 1] == splitter.GetViewSizes(0)
    assert [2, 2] == splitter.GetViewSizes(1)
491
492
def test_createdescriptorforconcatenation_ctor():
    """CreateDescriptorForConcatenation builds one view per input, stacked along the axis.

    The original test fetched the view origins into unused locals and never checked them.
    """
    input_shape_vector = [ann.TensorShape((2, 1)), ann.TensorShape((3, 1)), ann.TensorShape((4, 1))]
    desc = ann.CreateDescriptorForConcatenation(input_shape_vector, 0)
    assert 3 == desc.GetNumViews()
    assert 0 == desc.GetConcatAxis()
    assert 2 == desc.GetNumDimensions()

    # Concatenating along axis 0 places each view at the running sum of the preceding
    # inputs' sizes on that axis (0, 2, 2 + 3), with 0 on the other axis.
    # Only the first GetNumDimensions() entries are compared: the list returned by the
    # binding is not guaranteed to have rank length (see test_view_descriptor_ctor, where
    # a 3-D splitter yields 2-element origins).
    num_dims = desc.GetNumDimensions()
    for view, expected_origin in enumerate(([0, 0], [2, 0], [5, 0])):
        assert expected_origin == desc.GetViewOrigin(view)[:num_dims]
501
502
def test_createdescriptorforconcatenation_wrong_shape_for_axis():
    """Inputs that differ outside the concatenation axis are rejected with RuntimeError."""
    shapes = [ann.TensorShape((1, 2)), ann.TensorShape((3, 4)), ann.TensorShape((5, 6))]
    with pytest.raises(RuntimeError) as excinfo:
        ann.CreateDescriptorForConcatenation(shapes, 0)

    expected = "All inputs to concatenation must be the same size along all dimensions except the concatenation dimension"
    assert expected in str(excinfo.value)
510
511
@pytest.mark.parametrize(
    "input_shape_vector",
    [
        [-1, "one"],
        [1.33, 4.55],
        [{1: "one"}],
    ],
    ids=str,
)
def test_createdescriptorforconcatenation_rubbish_assignment_shape_vector(input_shape_vector):
    """Non-TensorShape inputs raise a TypeError that names the expected C++ argument type."""
    expected = "in method 'CreateDescriptorForConcatenation', argument 1 of type 'std::vector< armnn::TensorShape,std::allocator< armnn::TensorShape > >'"
    with pytest.raises(TypeError) as excinfo:
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    assert expected in str(excinfo.value)
521
522
# (name, class) pairs for every class exported by the generated pyarmnn bindings module.
generated_classes = inspect.getmembers(generated, inspect.isclass)
generated_classes_names = [name for name, _ in generated_classes]


# NOTE(review): Pooling3dDescriptor is tested above but is not listed here. Consider
# adding it once its equality operator is confirmed to be bound.
@pytest.mark.parametrize("desc_name", ['ActivationDescriptor',
                                       'ArgMinMaxDescriptor',
                                       'PermuteDescriptor',
                                       'SoftmaxDescriptor',
                                       'ConcatDescriptor',
                                       'SplitterDescriptor',
                                       'Pooling2dDescriptor',
                                       'FullyConnectedDescriptor',
                                       'Convolution2dDescriptor',
                                       'Convolution3dDescriptor',
                                       'DepthwiseConvolution2dDescriptor',
                                       'DetectionPostProcessDescriptor',
                                       'NormalizationDescriptor',
                                       'L2NormalizationDescriptor',
                                       'BatchNormalizationDescriptor',
                                       'InstanceNormalizationDescriptor',
                                       'BatchToSpaceNdDescriptor',
                                       'FakeQuantizationDescriptor',
                                       'ReduceDescriptor',
                                       'ResizeDescriptor',
                                       'ReshapeDescriptor',
                                       'SpaceToBatchNdDescriptor',
                                       'SpaceToDepthDescriptor',
                                       'LstmDescriptor',
                                       'MeanDescriptor',
                                       'PadDescriptor',
                                       'SliceDescriptor',
                                       'StackDescriptor',
                                       'StridedSliceDescriptor',
                                       'TransposeConvolution2dDescriptor',
                                       'TransposeDescriptor',
                                       'ElementwiseUnaryDescriptor',
                                       'FillDescriptor',
                                       'GatherDescriptor',
                                       'LogicalBinaryDescriptor',
                                       'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
    """Checks applied to every descriptor class named in the parametrization.

    Defined exactly once: a second verbatim copy of this class (and of the globals
    above) would silently rebind the same names.
    """

    def test_desc_implemented(self, desc_name):
        """The descriptor class is exported by the generated bindings module."""
        assert desc_name in generated_classes_names

    def test_desc_equal(self, desc_name):
        """Two default-constructed instances of the descriptor compare equal."""
        desc_class = dict(generated_classes)[desc_name]

        assert desc_class() == desc_class()
619