blob: 80c5359eb2604b4cb3c1a3d6404b2871f043fe30 [file] [log] [blame]
Teresa Charlin18147332021-11-17 14:34:30 +00001# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
Richard Burtondc0c6ed2020-04-08 16:39:05 +01002# SPDX-License-Identifier: MIT
3import inspect
4
5import pytest
6
7import pyarmnn as ann
8import numpy as np
9import pyarmnn._generated.pyarmnn as generated
10
11
def test_activation_descriptor_default_values():
    """ActivationDescriptor() defaults to Sigmoid with both A and B set to 0."""
    activation = ann.ActivationDescriptor()
    assert activation.m_Function == ann.ActivationFunction_Sigmoid
    assert (activation.m_A, activation.m_B) == (0, 0)
16 assert desc.m_B == 0
17
18
def test_argminmax_descriptor_default_values():
    """ArgMinMaxDescriptor() defaults to the Min function along axis -1."""
    arg_min_max = ann.ArgMinMaxDescriptor()
    assert (arg_min_max.m_Function, arg_min_max.m_Axis) == (ann.ArgMinMaxFunction_Min, -1)
22 assert desc.m_Axis == -1
23
24
def test_batchnormalization_descriptor_default_values():
    """BatchNormalizationDescriptor() defaults to NCHW layout and epsilon 1e-4.

    The epsilon comparison is asserted: a bare ``np.allclose`` call only
    computes a bool and can never fail the test.
    """
    desc = ann.BatchNormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # Tolerance-based comparison: m_Eps is presumably stored as a C float.
    assert np.allclose(0.0001, desc.m_Eps)
29
30
def test_batchtospacend_descriptor_default_values():
    """BatchToSpaceNdDescriptor() defaults to NCHW, block shape [1, 1] and zero crops."""
    batch_to_space = ann.BatchToSpaceNdDescriptor()
    assert batch_to_space.m_DataLayout == ann.DataLayout_NCHW
    assert batch_to_space.m_BlockShape == [1, 1]
    assert batch_to_space.m_Crops == [(0, 0), (0, 0)]
36
37
def test_batchtospacend_descriptor_assignment():
    """BatchToSpaceNdDescriptor fields can be reassigned and read back.

    Also checks that assigning ``m_Crops`` does not change the length of the
    caller's list.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    desc.m_BlockShape = (1, 2, 3)

    crops = [(1, 2), (3, 4)]
    crops_len_before = len(crops)
    desc.m_Crops = crops
    assert len(crops) == crops_len_before

    desc.m_DataLayout = ann.DataLayout_NHWC
    assert desc.m_DataLayout == ann.DataLayout_NHWC
    assert desc.m_BlockShape == [1, 2, 3]
    assert desc.m_Crops == [(1, 2), (3, 4)]
51
52
@pytest.mark.parametrize(
    "input_shape, value, vtype",
    [
        ([-1], -1, 'int'),
        (("one", "two"), "'one'", 'str'),
        ([1.33, 4.55], 1.33, 'float'),
        ([{1: "one"}], "{1: 'one'}", 'dict'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_shape(input_shape, value, vtype):
    """Assigning items that are not unsigned ints to m_BlockShape raises a descriptive TypeError."""
    desc = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as err:
        desc.m_BlockShape = input_shape

    expected = "Failed to convert python input value {} of type '{}' to C type 'j'".format(value, vtype)
    assert expected in str(err.value)
62
63
@pytest.mark.parametrize(
    "input_crops, value, vtype",
    [
        ([(1, 2), (3, 4, 5)], '(3, 4, 5)', 'tuple'),
        ([(1, 'one')], "(1, 'one')", 'tuple'),
        ([-1], -1, 'int'),
        ([(1, (1, 2))], '(1, (1, 2))', 'tuple'),
        ([[1, [1, 2]]], '[1, [1, 2]]', 'list'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_crops(input_crops, value, vtype):
    """Assigning malformed crop entries to m_Crops raises a descriptive TypeError."""
    desc = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as err:
        desc.m_Crops = input_crops

    expected = "Failed to convert python input value {} of type '{}' to C type".format(value, vtype)
    assert expected in str(err.value)
76
77
def test_batchtospacend_descriptor_empty_assignment():
    """An empty list is accepted by m_BlockShape and reads back as []."""
    batch_to_space = ann.BatchToSpaceNdDescriptor()
    batch_to_space.m_BlockShape = []
    assert batch_to_space.m_BlockShape == []
82
83
def test_batchtospacend_descriptor_ctor():
    """The (block_shape, crops) constructor stores both and keeps the NCHW default layout."""
    batch_to_space = ann.BatchToSpaceNdDescriptor([1, 2, 3], [(4, 5), (6, 7)])
    assert batch_to_space.m_DataLayout == ann.DataLayout_NCHW
    assert batch_to_space.m_BlockShape == [1, 2, 3]
    assert batch_to_space.m_Crops == [(4, 5), (6, 7)]
89
90
def test_channelshuffle_descriptor_default_values():
    """ChannelShuffleDescriptor() starts with axis 0 and zero groups."""
    shuffle = ann.ChannelShuffleDescriptor()
    assert (shuffle.m_Axis, shuffle.m_NumGroups) == (0, 0)
95
def test_convolution2d_descriptor_default_values():
    """Convolution2dDescriptor() defaults: no padding, unit stride and dilation, no bias, NCHW."""
    desc = ann.Convolution2dDescriptor()
    expected_defaults = {
        'm_PadLeft': 0,
        'm_PadTop': 0,
        'm_PadRight': 0,
        'm_PadBottom': 0,
        'm_StrideX': 1,
        'm_StrideY': 1,
        'm_DilationX': 1,
        'm_DilationY': 1,
        'm_BiasEnabled': False,
        'm_DataLayout': ann.DataLayout_NCHW,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
108
def test_convolution3d_descriptor_default_values():
    """Convolution3dDescriptor() defaults: no padding, unit stride and dilation, no bias, NDHWC."""
    desc = ann.Convolution3dDescriptor()
    expected_defaults = {
        'm_PadLeft': 0,
        'm_PadTop': 0,
        'm_PadRight': 0,
        'm_PadBottom': 0,
        'm_PadFront': 0,
        'm_PadBack': 0,
        'm_StrideX': 1,
        'm_StrideY': 1,
        'm_StrideZ': 1,
        'm_DilationX': 1,
        'm_DilationY': 1,
        'm_DilationZ': 1,
        'm_BiasEnabled': False,
        'm_DataLayout': ann.DataLayout_NDHWC,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
125
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100126
def test_depthtospace_descriptor_default_values():
    """DepthToSpaceDescriptor() defaults to block size 1 in NHWC layout."""
    depth_to_space = ann.DepthToSpaceDescriptor()
    assert depth_to_space.m_BlockSize == 1
    assert depth_to_space.m_DataLayout == ann.DataLayout_NHWC
131
132
def test_depthwise_convolution2d_descriptor_default_values():
    """DepthwiseConvolution2dDescriptor() defaults: no padding, unit stride/dilation, no bias, NCHW."""
    desc = ann.DepthwiseConvolution2dDescriptor()
    expected_defaults = {
        'm_PadLeft': 0,
        'm_PadTop': 0,
        'm_PadRight': 0,
        'm_PadBottom': 0,
        'm_StrideX': 1,
        'm_StrideY': 1,
        'm_DilationX': 1,
        'm_DilationY': 1,
        'm_BiasEnabled': False,
        'm_DataLayout': ann.DataLayout_NCHW,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
145
146
def test_detectionpostprocess_descriptor_default_values():
    """DetectionPostProcessDescriptor() default field values."""
    desc = ann.DetectionPostProcessDescriptor()
    expected_defaults = {
        'm_MaxDetections': 0,
        'm_MaxClassesPerDetection': 1,
        'm_DetectionsPerClass': 1,
        'm_NmsScoreThreshold': 0,
        'm_NmsIouThreshold': 0,
        'm_NumClasses': 0,
        'm_UseRegularNms': False,
        'm_ScaleH': 0,
        'm_ScaleW': 0,
        'm_ScaleX': 0,
        'm_ScaleY': 0,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
160
161
def test_fakequantization_descriptor_default_values():
    """FakeQuantizationDescriptor() defaults to the range [-6, 6].

    Both comparisons are asserted: a bare ``np.allclose`` call only computes
    a bool and can never fail the test.
    """
    desc = ann.FakeQuantizationDescriptor()
    assert np.allclose(6, desc.m_Max)
    assert np.allclose(-6, desc.m_Min)
166
167
def test_fill_descriptor_default_values():
    """FillDescriptor() defaults to a fill value of 0.

    The comparison is asserted: a bare ``np.allclose`` call only computes a
    bool and can never fail the test.
    """
    desc = ann.FillDescriptor()
    assert np.allclose(0, desc.m_Value)
171
172
def test_gather_descriptor_default_values():
    """GatherDescriptor() gathers along axis 0 by default."""
    gather = ann.GatherDescriptor()
    assert gather.m_Axis == 0
176
177
def test_fully_connected_descriptor_default_values():
    """FullyConnectedDescriptor() has bias and weight-matrix transposition disabled."""
    fully_connected = ann.FullyConnectedDescriptor()
    assert (fully_connected.m_BiasEnabled, fully_connected.m_TransposeWeightMatrix) == (False, False)
182
183
def test_instancenormalization_descriptor_default_values():
    """InstanceNormalizationDescriptor() defaults: gamma 1, beta 0, NCHW, epsilon 1e-12."""
    desc = ann.InstanceNormalizationDescriptor()
    assert desc.m_Gamma == 1
    assert desc.m_Beta == 0
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # The result must be asserted, and the tolerance must be purely relative.
    # With np.allclose's default atol=1e-8, any value below ~1e-8 (including 0)
    # would be "close" to 1e-12, so the check would be vacuous.
    assert np.allclose(1e-12, desc.m_Eps, rtol=1e-5, atol=0)
190
191
def test_lstm_descriptor_default_values():
    """LstmDescriptor() defaults: activation 1, no clipping, CIFG on, other optional features off."""
    desc = ann.LstmDescriptor()
    expected_defaults = {
        'm_ActivationFunc': 1,
        'm_ClippingThresCell': 0,
        'm_ClippingThresProj': 0,
        'm_CifgEnabled': True,
        'm_PeepholeEnabled': False,
        'm_ProjectionEnabled': False,
        'm_LayerNormEnabled': False,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
201
202
def test_l2normalization_descriptor_default_values():
    """L2NormalizationDescriptor() defaults to NCHW layout and epsilon 1e-12."""
    desc = ann.L2NormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # The result must be asserted, and the tolerance must be purely relative.
    # With np.allclose's default atol=1e-8, any value below ~1e-8 (including 0)
    # would be "close" to 1e-12, so the check would be vacuous.
    assert np.allclose(1e-12, desc.m_Eps, rtol=1e-5, atol=0)
207
208
def test_mean_descriptor_default_values():
    """MeanDescriptor() does not keep reduced dimensions by default."""
    mean = ann.MeanDescriptor()
    keep_dims = mean.m_KeepDims
    assert keep_dims == False
212
213
def test_normalization_descriptor_default_values():
    """NormalizationDescriptor() defaults: Across channels, LocalBrightness, zero parameters, NCHW."""
    desc = ann.NormalizationDescriptor()
    expected_defaults = {
        'm_NormChannelType': ann.NormalizationAlgorithmChannel_Across,
        'm_NormMethodType': ann.NormalizationAlgorithmMethod_LocalBrightness,
        'm_NormSize': 0,
        'm_Alpha': 0,
        'm_Beta': 0,
        'm_K': 0,
        'm_DataLayout': ann.DataLayout_NCHW,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
223
224
def test_origin_descriptor_default_values():
    """ConcatDescriptor() has no views, no dimensions and concat axis 1."""
    concat = ann.ConcatDescriptor()
    assert concat.GetNumViews() == 0
    assert concat.GetNumDimensions() == 0
    assert concat.GetConcatAxis() == 1
230
231
def test_origin_descriptor_incorrect_views():
    """Out-of-range indices make ConcatDescriptor.SetViewOriginCoord raise RuntimeError."""
    concat = ann.ConcatDescriptor(2, 2)
    with pytest.raises(RuntimeError) as excinfo:
        concat.SetViewOriginCoord(1000, 100, 1000)
    message = str(excinfo.value)
    assert "Failed to set view origin coordinates." in message
237
238
def test_origin_descriptor_ctor():
    """Per-view origins set through SetViewOriginCoord are returned by GetViewOrigin."""
    concat = ann.ConcatDescriptor(2, 2)
    base = 5
    # View i gets origin coordinate (base + i) in every dimension.
    for view in range(concat.GetNumViews()):
        for dim in range(concat.GetNumDimensions()):
            concat.SetViewOriginCoord(view, dim, base + view)
    concat.SetConcatAxis(1)

    assert concat.GetNumViews() == 2
    assert concat.GetNumDimensions() == 2
    assert concat.GetViewOrigin(0) == [5, 5]
    assert concat.GetViewOrigin(1) == [6, 6]
    assert concat.GetConcatAxis() == 1
252
253
def test_pad_descriptor_default_values():
    """PadDescriptor() pads with value 0 using the Constant padding mode."""
    pad = ann.PadDescriptor()
    assert (pad.m_PadValue, pad.m_PaddingMode) == (0, ann.PaddingMode_Constant)
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100258
259
def test_permute_descriptor_default_values():
    """PermuteDescriptor exposes the PermutationVector it was built from as m_DimMappings."""
    mappings = (0, 2, 3, 1)
    desc = ann.PermuteDescriptor(ann.PermutationVector(mappings))
    assert desc.m_DimMappings.GetSize() == len(mappings)
    for index, expected in enumerate(mappings):
        assert desc.m_DimMappings[index] == expected
268
269
def test_pooling_descriptor_default_values():
    """Pooling2dDescriptor() defaults: Max pooling, all sizes zero, Floor rounding, Exclude padding, NCHW."""
    desc = ann.Pooling2dDescriptor()
    expected_defaults = {
        'm_PoolType': ann.PoolingAlgorithm_Max,
        'm_PadLeft': 0,
        'm_PadTop': 0,
        'm_PadRight': 0,
        'm_PadBottom': 0,
        'm_PoolHeight': 0,
        'm_PoolWidth': 0,
        'm_StrideX': 0,
        'm_StrideY': 0,
        'm_OutputShapeRounding': ann.OutputShapeRounding_Floor,
        'm_PaddingMethod': ann.PaddingMethod_Exclude,
        'm_DataLayout': ann.DataLayout_NCHW,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
284
285
def test_reshape_descriptor_default_values():
    """ReshapeDescriptor() starts with an empty (zero-dimensional) target shape."""
    reshape = ann.ReshapeDescriptor()
    target_shape = reshape.m_TargetShape
    assert target_shape.GetNumDimensions() == 0
290
291
def test_slice_descriptor_default_begin_and_size():
    """SliceDescriptor() starts with empty begin and size vectors.

    The name is distinct from the assignment test later in this module, so
    pytest collects both; two functions with the same name would leave only
    the last one visible.
    """
    desc = ann.SliceDescriptor()
    assert desc.m_Begin == []
    assert desc.m_Size == []
298
299
def test_resize_descriptor_default_values():
    """ResizeDescriptor() defaults: zero target size, NearestNeighbor, NCHW, no corner alignment."""
    desc = ann.ResizeDescriptor()
    expected_defaults = {
        'm_TargetWidth': 0,
        'm_TargetHeight': 0,
        'm_Method': ann.ResizeMethod_NearestNeighbor,
        'm_DataLayout': ann.DataLayout_NCHW,
        'm_AlignCorners': False,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100307
308
def test_spacetobatchnd_descriptor_default_values():
    """SpaceToBatchNdDescriptor() defaults to NCHW layout."""
    space_to_batch = ann.SpaceToBatchNdDescriptor()
    assert space_to_batch.m_DataLayout == ann.DataLayout_NCHW
312
313
def test_spacetodepth_descriptor_default_values():
    """SpaceToDepthDescriptor() defaults to block size 1 in NHWC layout."""
    space_to_depth = ann.SpaceToDepthDescriptor()
    assert (space_to_depth.m_BlockSize, space_to_depth.m_DataLayout) == (1, ann.DataLayout_NHWC)
318
319
def test_stack_descriptor_default_values():
    """StackDescriptor() has axis 0, no inputs and an empty input shape."""
    stack = ann.StackDescriptor()
    assert (stack.m_Axis, stack.m_NumInputs) == (0, 0)
    assert stack.m_InputShape.GetNumDimensions() == 0
326
327
def test_slice_descriptor_default_values():
    """Begin/size assigned to a SliceDescriptor read back as lists (tuple input included)."""
    slice_desc = ann.SliceDescriptor()
    slice_desc.m_Begin = [1, 2, 3, 4, 5]
    slice_desc.m_Size = (1, 2, 3, 4)

    assert slice_desc.m_Begin == [1, 2, 3, 4, 5]
    assert slice_desc.m_Size == [1, 2, 3, 4]
335
336
def test_slice_descriptor_ctor():
    """The (begin, size) constructor stores both vectors; tuple input reads back as a list."""
    slice_desc = ann.SliceDescriptor([1, 2, 3, 4, 5], (1, 2, 3, 4))

    assert slice_desc.m_Begin == [1, 2, 3, 4, 5]
    assert slice_desc.m_Size == [1, 2, 3, 4]
342
343
def test_strided_slice_descriptor_default_values():
    """Every StridedSliceDescriptor vector and mask field can be assigned and read back."""
    desc = ann.StridedSliceDescriptor()
    desc.m_Begin = [1, 2, 3, 4, 5]
    desc.m_End = [6, 7, 8, 9, 10]
    desc.m_Stride = (10, 10)
    masks = {
        'm_BeginMask': 1,
        'm_EndMask': 2,
        'm_ShrinkAxisMask': 3,
        'm_EllipsisMask': 4,
        'm_NewAxisMask': 5,
    }
    for name, value in masks.items():
        setattr(desc, name, value)

    assert desc.m_Begin == [1, 2, 3, 4, 5]
    assert desc.m_End == [6, 7, 8, 9, 10]
    assert desc.m_Stride == [10, 10]
    for name, value in masks.items():
        assert getattr(desc, name) == value, name
363
364
def test_strided_slice_descriptor_ctor():
    """The (begin, end, stride) constructor stores all three vectors.

    The fields are read straight back after construction without being
    re-assigned, so the constructor itself (not the setters) is what is
    under test.
    """
    desc = ann.StridedSliceDescriptor([1, 2, 3, 4, 5], [6, 7, 8, 9, 10], (10, 10))

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    assert [10, 10] == desc.m_Stride
374
375
def test_softmax_descriptor_default_values():
    """SoftmaxDescriptor() defaults to axis -1 and beta 1.0.

    The beta comparison is asserted: a bare ``np.allclose`` call only
    computes a bool and can never fail the test.
    """
    desc = ann.SoftmaxDescriptor()
    assert desc.m_Axis == -1
    assert np.allclose(1.0, desc.m_Beta)
380
381
def test_space_to_batch_nd_descriptor_default_values():
    """SpaceToBatchNdDescriptor() defaults to block shape [1, 1], zero padding and NCHW."""
    space_to_batch = ann.SpaceToBatchNdDescriptor()
    assert space_to_batch.m_BlockShape == [1, 1]
    assert space_to_batch.m_PadList == [(0, 0), (0, 0)]
    assert space_to_batch.m_DataLayout == ann.DataLayout_NCHW
387
388
def test_space_to_batch_nd_descriptor_assigned_values():
    """Assigned block shape and pad list read back; the layout keeps its NCHW default."""
    space_to_batch = ann.SpaceToBatchNdDescriptor()
    space_to_batch.m_BlockShape = (90, 100)
    space_to_batch.m_PadList = [(1, 2), (3, 4)]

    assert space_to_batch.m_BlockShape == [90, 100]
    assert space_to_batch.m_PadList == [(1, 2), (3, 4)]
    assert space_to_batch.m_DataLayout == ann.DataLayout_NCHW
396
397
def test_space_to_batch_nd_descriptor_ctor():
    """The (block_shape, pad_list) constructor stores both and keeps the NCHW default layout."""
    space_to_batch = ann.SpaceToBatchNdDescriptor((1, 2, 3), [(1, 2), (3, 4)])
    assert space_to_batch.m_BlockShape == [1, 2, 3]
    assert space_to_batch.m_PadList == [(1, 2), (3, 4)]
    assert space_to_batch.m_DataLayout == ann.DataLayout_NCHW
403
404
def test_transpose_convolution2d_descriptor_default_values():
    """TransposeConvolution2dDescriptor() defaults: zero padding/stride, no bias, NCHW, no output shape."""
    desc = ann.TransposeConvolution2dDescriptor()
    expected_defaults = {
        'm_PadLeft': 0,
        'm_PadTop': 0,
        'm_PadRight': 0,
        'm_PadBottom': 0,
        'm_StrideX': 0,
        'm_StrideY': 0,
        'm_BiasEnabled': False,
        'm_DataLayout': ann.DataLayout_NCHW,
        'm_OutputShapeEnabled': False,
    }
    for field, expected in expected_defaults.items():
        assert getattr(desc, field) == expected, field
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100416
def test_transpose_descriptor_default_values():
    """TransposeDescriptor exposes the PermutationVector it was built from as m_DimMappings."""
    mappings = (0, 3, 2, 1, 4)
    desc = ann.TransposeDescriptor(ann.PermutationVector(mappings))
    assert desc.m_DimMappings.GetSize() == len(mappings)
    for index, expected in enumerate(mappings):
        assert desc.m_DimMappings[index] == expected
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100426
def test_view_descriptor_default_values():
    """SplitterDescriptor() has no views and no dimensions."""
    splitter = ann.SplitterDescriptor()
    assert (splitter.GetNumViews(), splitter.GetNumDimensions()) == (0, 0)
431
432
def test_elementwise_unary_descriptor_default_values():
    """ElementwiseUnaryDescriptor() defaults to the Abs operation."""
    unary = ann.ElementwiseUnaryDescriptor()
    assert unary.m_Operation == ann.UnaryOperation_Abs
436
437
def test_logical_binary_descriptor_default_values():
    """LogicalBinaryDescriptor() defaults to the LogicalAnd operation."""
    logical = ann.LogicalBinaryDescriptor()
    assert logical.m_Operation == ann.LogicalBinaryOperation_LogicalAnd
441
def test_view_descriptor_incorrect_input():
    """Out-of-range indices make SetViewOriginCoord and SetViewSize raise RuntimeError."""
    splitter = ann.SplitterDescriptor(2, 3)

    with pytest.raises(RuntimeError) as origin_err:
        splitter.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(origin_err.value)

    with pytest.raises(RuntimeError) as size_err:
        splitter.SetViewSize(1000, 100, 1000)
    assert "Failed to set view size." in str(size_err.value)
451
452
def test_view_descriptor_ctor():
    """Origins and sizes set per view are returned by GetViewOrigin / GetViewSizes."""
    splitter = ann.SplitterDescriptor(2, 3)
    size_base = 1
    origin_base = 5
    # View i gets origin (origin_base + i) and size (size_base + i) in every dimension.
    for view in range(splitter.GetNumViews()):
        for dim in range(splitter.GetNumDimensions()):
            splitter.SetViewOriginCoord(view, dim, origin_base + view)
            splitter.SetViewSize(view, dim, size_base + view)

    assert splitter.GetNumViews() == 2
    assert splitter.GetNumDimensions() == 3
    # NOTE(review): the descriptor has 3 dimensions, yet the getters are expected
    # to return 2-element lists here — confirm this output length is intended.
    assert splitter.GetViewOrigin(0) == [5, 5]
    assert splitter.GetViewOrigin(1) == [6, 6]
    assert splitter.GetViewSizes(0) == [1, 1]
    assert splitter.GetViewSizes(1) == [2, 2]
468
469
def test_createdescriptorforconcatenation_ctor():
    """CreateDescriptorForConcatenation builds one view per input shape along the given axis."""
    input_shape_vector = [ann.TensorShape((2, 1)), ann.TensorShape((3, 1)), ann.TensorShape((4, 1))]
    desc = ann.CreateDescriptorForConcatenation(input_shape_vector, 0)
    assert 3 == desc.GetNumViews()
    assert 0 == desc.GetConcatAxis()
    assert 2 == desc.GetNumDimensions()
    # Smoke-check that view origins can be queried; their values are not asserted
    # here, so the previously unused result bindings are dropped.
    desc.GetViewOrigin(1)
    desc.GetViewOrigin(0)
478
479
def test_createdescriptorforconcatenation_wrong_shape_for_axis():
    """Inputs whose non-concat dimensions differ are rejected with RuntimeError."""
    input_shape_vector = [ann.TensorShape((1, 2)), ann.TensorShape((3, 4)), ann.TensorShape((5, 6))]
    with pytest.raises(RuntimeError) as err:
        # Only the raised exception matters; the return value is never produced.
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    expected = ("All inputs to concatenation must be the same size along all dimensions "
                "except the concatenation dimension")
    assert expected in str(err.value)
487
488
@pytest.mark.parametrize("input_shape_vector", [[-1, "one"],
                                                [1.33, 4.55],
                                                [{1: "one"}]], ids=lambda x: str(x))
def test_createdescriptorforconcatenation_rubbish_assignment_shape_vector(input_shape_vector):
    """Passing something other than a list of TensorShape raises a SWIG TypeError."""
    with pytest.raises(TypeError) as err:
        # Only the raised exception matters; the return value is never produced.
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    expected = ("in method 'CreateDescriptorForConcatenation', argument 1 of type "
                "'std::vector< armnn::TensorShape,std::allocator< armnn::TensorShape > >'")
    assert expected in str(err.value)
498
499
# (name, class) pairs for every class exported by the generated bindings module.
generated_classes = inspect.getmembers(generated, inspect.isclass)
generated_classes_names = [name for name, _ in generated_classes]

# Descriptor classes that must be exported and must be equality-comparable.
# Defined once: an identical second definition of the parametrized class
# would simply replace this one at import time.
_MASS_CHECKED_DESCRIPTORS = [
    'ActivationDescriptor',
    'ArgMinMaxDescriptor',
    'PermuteDescriptor',
    'SoftmaxDescriptor',
    'ConcatDescriptor',
    'SplitterDescriptor',
    'Pooling2dDescriptor',
    'FullyConnectedDescriptor',
    'Convolution2dDescriptor',
    'Convolution3dDescriptor',
    'DepthwiseConvolution2dDescriptor',
    'DetectionPostProcessDescriptor',
    'NormalizationDescriptor',
    'L2NormalizationDescriptor',
    'BatchNormalizationDescriptor',
    'InstanceNormalizationDescriptor',
    'BatchToSpaceNdDescriptor',
    'FakeQuantizationDescriptor',
    'ResizeDescriptor',
    'ReshapeDescriptor',
    'SpaceToBatchNdDescriptor',
    'SpaceToDepthDescriptor',
    'LstmDescriptor',
    'MeanDescriptor',
    'PadDescriptor',
    'SliceDescriptor',
    'StackDescriptor',
    'StridedSliceDescriptor',
    'TransposeConvolution2dDescriptor',
    'TransposeDescriptor',
    'ElementwiseUnaryDescriptor',
    'FillDescriptor',
    'GatherDescriptor',
    'LogicalBinaryDescriptor',
    'ChannelShuffleDescriptor',
]


@pytest.mark.parametrize("desc_name", _MASS_CHECKED_DESCRIPTORS)
class TestDescriptorMassChecks:
    """Checks applied to every descriptor name in _MASS_CHECKED_DESCRIPTORS."""

    def test_desc_implemented(self, desc_name):
        """The descriptor class is exported by pyarmnn._generated.pyarmnn."""
        assert desc_name in generated_classes_names

    def test_desc_equal(self, desc_name):
        """Two default-constructed instances of the descriptor compare equal."""
        # A missing class surfaces as a KeyError naming the descriptor,
        # rather than an opaque StopIteration from next(filter(...)).
        desc_class = dict(generated_classes)[desc_name]

        assert desc_class() == desc_class()
594