blob: 54b79d7397d2d308e8d2b920c66138ff519ea4c4 [file] [log] [blame]
Teresa Charlin18147332021-11-17 14:34:30 +00001# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
Richard Burtondc0c6ed2020-04-08 16:39:05 +01002# SPDX-License-Identifier: MIT
3import inspect
4
5import pytest
6
7import pyarmnn as ann
8import numpy as np
9import pyarmnn._generated.pyarmnn as generated
10
11
def test_activation_descriptor_default_values():
    """A default ActivationDescriptor selects Sigmoid with A and B both zero."""
    descriptor = ann.ActivationDescriptor()
    assert ann.ActivationFunction_Sigmoid == descriptor.m_Function
    assert 0 == descriptor.m_A
    assert 0 == descriptor.m_B
16 assert desc.m_B == 0
17
18
def test_argminmax_descriptor_default_values():
    """A default ArgMinMaxDescriptor computes Min along axis -1."""
    descriptor = ann.ArgMinMaxDescriptor()
    assert (descriptor.m_Function, descriptor.m_Axis) == (ann.ArgMinMaxFunction_Min, -1)
22 assert desc.m_Axis == -1
23
24
def test_batchnormalization_descriptor_default_values():
    """A default BatchNormalizationDescriptor uses NCHW layout and epsilon 1e-4.

    Epsilon is compared approximately because the value read back through the
    binding may not be bit-identical to the Python literal.
    """
    desc = ann.BatchNormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # np.allclose only returns a bool; without `assert` this check did nothing.
    assert np.allclose(0.0001, desc.m_Eps)
29
30
def test_batchtospacend_descriptor_default_values():
    """A default BatchToSpaceNdDescriptor: NCHW layout, block shape [1, 1], zero crops."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    no_crop = (0, 0)
    assert ann.DataLayout_NCHW == descriptor.m_DataLayout
    assert descriptor.m_BlockShape == [1, 1]
    assert descriptor.m_Crops == [no_crop, no_crop]
35 assert [(0, 0), (0, 0)] == desc.m_Crops
36
37
def test_batchtospacend_descriptor_assignment():
    """Block shape, crops and data layout can be assigned and read back.

    Also checks that assigning a Python list to m_Crops leaves that list's
    length unchanged.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    desc.m_BlockShape = (1, 2, 3)

    crops = [(1, 2), (3, 4)]
    crops_len_before = len(crops)
    desc.m_Crops = crops
    # The source list must survive the conversion untouched.
    assert len(crops) == crops_len_before

    desc.m_DataLayout = ann.DataLayout_NHWC
    assert desc.m_DataLayout == ann.DataLayout_NHWC
    # A tuple assigned to m_BlockShape is read back as a list.
    assert desc.m_BlockShape == [1, 2, 3]
    assert desc.m_Crops == [(1, 2), (3, 4)]
50 assert [(1, 2), (3, 4)] == desc.m_Crops
51
52
@pytest.mark.parametrize(
    "input_shape, value, vtype",
    [
        ([-1], -1, 'int'),
        (("one", "two"), "'one'", 'str'),
        ([1.33, 4.55], 1.33, 'float'),
        ([{1: "one"}], "{1: 'one'}", 'dict'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_shape(input_shape, value, vtype):
    """Assigning a block shape with a negative or non-integer element raises TypeError.

    The message names the first offending element, its Python type and the
    target C type 'j'.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    expected = "Failed to convert python input value {} of type '{}' to C type 'j'".format(value, vtype)
    with pytest.raises(TypeError) as err:
        desc.m_BlockShape = input_shape
    assert expected in str(err.value)
62
63
@pytest.mark.parametrize(
    "input_crops, value, vtype",
    [
        ([(1, 2), (3, 4, 5)], '(3, 4, 5)', 'tuple'),
        ([(1, 'one')], "(1, 'one')", 'tuple'),
        ([-1], -1, 'int'),
        ([(1, (1, 2))], '(1, (1, 2))', 'tuple'),
        ([[1, [1, 2]]], '[1, [1, 2]]', 'list'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_crops(input_crops, value, vtype):
    """Assigning crops that are not a list of integer pairs raises TypeError naming the bad item."""
    desc = ann.BatchToSpaceNdDescriptor()
    expected = "Failed to convert python input value {} of type '{}' to C type".format(value, vtype)
    with pytest.raises(TypeError) as err:
        desc.m_Crops = input_crops
    assert expected in str(err.value)
76
77
def test_batchtospacend_descriptor_empty_assignment():
    """An empty block shape can be assigned and reads back as an empty list."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    empty_shape = []
    descriptor.m_BlockShape = empty_shape
    assert descriptor.m_BlockShape == []
81 assert [] == desc.m_BlockShape
82
83
def test_batchtospacend_descriptor_ctor():
    """The (block_shape, crops) constructor stores both values and keeps NCHW layout."""
    block_shape = [1, 2, 3]
    crops = [(4, 5), (6, 7)]
    desc = ann.BatchToSpaceNdDescriptor(block_shape, crops)
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    assert desc.m_BlockShape == [1, 2, 3]
    assert desc.m_Crops == [(4, 5), (6, 7)]
88 assert [(4, 5), (6, 7)] == desc.m_Crops
89
90
def test_convolution2d_descriptor_default_values():
    """A default Convolution2dDescriptor: no padding, unit stride and dilation, no bias, NCHW."""
    desc = ann.Convolution2dDescriptor()
    expected_defaults = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
103
def test_convolution3d_descriptor_default_values():
    """A default Convolution3dDescriptor: no padding, unit stride and dilation, no bias, NDHWC."""
    desc = ann.Convolution3dDescriptor()
    expected_defaults = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PadFront": 0,
        "m_PadBack": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_StrideZ": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_DilationZ": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NDHWC,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
120
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100121
def test_depthtospace_descriptor_default_values():
    """A default DepthToSpaceDescriptor uses block size 1 and NHWC layout."""
    descriptor = ann.DepthToSpaceDescriptor()
    assert (descriptor.m_BlockSize, descriptor.m_DataLayout) == (1, ann.DataLayout_NHWC)
126
127
def test_depthwise_convolution2d_descriptor_default_values():
    """A default DepthwiseConvolution2dDescriptor: no padding, unit stride and dilation, no bias, NCHW."""
    desc = ann.DepthwiseConvolution2dDescriptor()
    expected_defaults = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
140
141
def test_detectionpostprocess_descriptor_default_values():
    """A default DetectionPostProcessDescriptor has zeroed thresholds/scales and one class per detection."""
    desc = ann.DetectionPostProcessDescriptor()
    expected_defaults = {
        "m_MaxDetections": 0,
        "m_MaxClassesPerDetection": 1,
        "m_DetectionsPerClass": 1,
        "m_NmsScoreThreshold": 0,
        "m_NmsIouThreshold": 0,
        "m_NumClasses": 0,
        "m_UseRegularNms": False,
        "m_ScaleH": 0,
        "m_ScaleW": 0,
        "m_ScaleX": 0,
        "m_ScaleY": 0,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
155
156
def test_fakequantization_descriptor_default_values():
    """A default FakeQuantizationDescriptor spans the range [-6, 6]."""
    desc = ann.FakeQuantizationDescriptor()
    # np.allclose only returns a bool; the results must be asserted to have any effect.
    assert np.allclose(6, desc.m_Max)
    assert np.allclose(-6, desc.m_Min)
161
162
def test_fill_descriptor_default_values():
    """A default FillDescriptor fills with the value 0."""
    desc = ann.FillDescriptor()
    # np.allclose only returns a bool; without `assert` this check did nothing.
    assert np.allclose(0, desc.m_Value)
166
167
def test_gather_descriptor_default_values():
    """A default GatherDescriptor gathers along axis 0."""
    descriptor = ann.GatherDescriptor()
    default_axis = descriptor.m_Axis
    assert default_axis == 0
171
172
def test_fully_connected_descriptor_default_values():
    """A default FullyConnectedDescriptor has bias and weight transposition disabled."""
    descriptor = ann.FullyConnectedDescriptor()
    assert (descriptor.m_BiasEnabled, descriptor.m_TransposeWeightMatrix) == (False, False)
177
178
def test_instancenormalization_descriptor_default_values():
    """A default InstanceNormalizationDescriptor: gamma 1, beta 0, NCHW layout, epsilon 1e-12."""
    desc = ann.InstanceNormalizationDescriptor()
    assert desc.m_Gamma == 1
    assert desc.m_Beta == 0
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # np.allclose only returns a bool; without `assert` this check did nothing.
    assert np.allclose(1e-12, desc.m_Eps)
185
186
def test_lstm_descriptor_default_values():
    """A default LstmDescriptor has CIFG enabled and peephole/projection/layer-norm disabled."""
    desc = ann.LstmDescriptor()
    expected_defaults = {
        # Raw integer activation code; its meaning is not visible here — TODO confirm.
        "m_ActivationFunc": 1,
        "m_ClippingThresCell": 0,
        "m_ClippingThresProj": 0,
        "m_CifgEnabled": True,
        "m_PeepholeEnabled": False,
        "m_ProjectionEnabled": False,
        "m_LayerNormEnabled": False,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
196
197
def test_l2normalization_descriptor_default_values():
    """A default L2NormalizationDescriptor uses NCHW layout and epsilon 1e-12."""
    desc = ann.L2NormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # np.allclose only returns a bool; without `assert` this check did nothing.
    assert np.allclose(1e-12, desc.m_Eps)
202
203
def test_mean_descriptor_default_values():
    """A default MeanDescriptor does not keep the reduced dimensions."""
    descriptor = ann.MeanDescriptor()
    keep_dims = descriptor.m_KeepDims
    assert keep_dims == False
207
208
def test_normalization_descriptor_default_values():
    """A default NormalizationDescriptor: Across/LocalBrightness, zero size and coefficients, NCHW."""
    desc = ann.NormalizationDescriptor()
    expected_defaults = {
        "m_NormChannelType": ann.NormalizationAlgorithmChannel_Across,
        "m_NormMethodType": ann.NormalizationAlgorithmMethod_LocalBrightness,
        "m_NormSize": 0,
        "m_Alpha": 0,
        "m_Beta": 0,
        "m_K": 0,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
218
219
def test_origin_descriptor_default_values():
    """A default ConcatDescriptor has no views, no dimensions and concat axis 1."""
    descriptor = ann.ConcatDescriptor()
    assert descriptor.GetNumViews() == 0
    assert descriptor.GetNumDimensions() == 0
    assert descriptor.GetConcatAxis() == 1
225
226
def test_origin_descriptor_incorrect_views():
    """Setting an origin coordinate with out-of-range indices raises RuntimeError."""
    descriptor = ann.ConcatDescriptor(2, 2)
    bad_view, bad_coord, bad_value = 1000, 100, 1000
    with pytest.raises(RuntimeError) as err:
        descriptor.SetViewOriginCoord(bad_view, bad_coord, bad_value)
    assert "Failed to set view origin coordinates." in str(err.value)
232
233
def test_origin_descriptor_ctor():
    """Per-view origins set on a ConcatDescriptor(2, 2) read back, and the concat axis can be set."""
    desc = ann.ConcatDescriptor(2, 2)
    base = 5
    dims = range(desc.GetNumDimensions())
    for view in range(desc.GetNumViews()):
        # Every coordinate of view `view` is set to base + view.
        origin_value = base + view
        for dim in dims:
            desc.SetViewOriginCoord(view, dim, origin_value)
    desc.SetConcatAxis(1)

    assert desc.GetNumViews() == 2
    assert desc.GetNumDimensions() == 2
    assert desc.GetViewOrigin(0) == [5, 5]
    assert desc.GetViewOrigin(1) == [6, 6]
    assert desc.GetConcatAxis() == 1
247
248
def test_pad_descriptor_default_values():
    """A default PadDescriptor pads with value 0 in Constant mode."""
    descriptor = ann.PadDescriptor()
    assert (descriptor.m_PadValue, descriptor.m_PaddingMode) == (0, ann.PaddingMode_Constant)
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100253
254
def test_permute_descriptor_default_values():
    """A PermuteDescriptor exposes the mappings of the PermutationVector it was built from."""
    mappings = (0, 2, 3, 1)
    pv = ann.PermutationVector(mappings)
    desc = ann.PermuteDescriptor(pv)
    assert desc.m_DimMappings.GetSize() == len(mappings)
    for index, expected in enumerate(mappings):
        assert desc.m_DimMappings[index] == expected, index
263
264
def test_pooling_descriptor_default_values():
    """A default Pooling2dDescriptor: Max pooling, all sizes/pads/strides zero, Floor, Exclude, NCHW."""
    desc = ann.Pooling2dDescriptor()
    expected_defaults = {
        "m_PoolType": ann.PoolingAlgorithm_Max,
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PoolHeight": 0,
        "m_PoolWidth": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_OutputShapeRounding": ann.OutputShapeRounding_Floor,
        "m_PaddingMethod": ann.PaddingMethod_Exclude,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
279
280
def test_reshape_descriptor_default_values():
    """A default ReshapeDescriptor has an empty (zero-dimensional) target shape."""
    descriptor = ann.ReshapeDescriptor()
    target_shape = descriptor.m_TargetShape
    assert target_shape.GetNumDimensions() == 0
285
286
def test_slice_descriptor_empty_default_values():
    """A default-constructed SliceDescriptor has empty begin and size lists.

    This test previously shared its name with the SliceDescriptor assignment test
    defined later in this module, so it was shadowed and never collected. Its
    body also asserted resize-style fields (m_TargetWidth, m_Method, ...) that are
    covered by test_resize_descriptor_default_values.
    """
    desc = ann.SliceDescriptor()
    assert desc.m_Begin == []
    assert desc.m_Size == []
293
294
def test_resize_descriptor_default_values():
    """A default ResizeDescriptor: zero target size, nearest-neighbour, NCHW, no corner alignment."""
    desc = ann.ResizeDescriptor()
    expected_defaults = {
        "m_TargetWidth": 0,
        "m_TargetHeight": 0,
        "m_Method": ann.ResizeMethod_NearestNeighbor,
        "m_DataLayout": ann.DataLayout_NCHW,
        "m_AlignCorners": False,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100302
303
def test_spacetobatchnd_descriptor_default_values():
    """A default SpaceToBatchNdDescriptor uses the NCHW data layout."""
    descriptor = ann.SpaceToBatchNdDescriptor()
    assert ann.DataLayout_NCHW == descriptor.m_DataLayout
307
308
def test_spacetodepth_descriptor_default_values():
    """A default SpaceToDepthDescriptor uses block size 1 and NHWC layout."""
    descriptor = ann.SpaceToDepthDescriptor()
    assert (descriptor.m_BlockSize, descriptor.m_DataLayout) == (1, ann.DataLayout_NHWC)
313
314
def test_stack_descriptor_default_values():
    """A default StackDescriptor: axis 0, zero inputs and an empty input shape."""
    desc = ann.StackDescriptor()
    assert desc.m_Axis == 0
    assert desc.m_NumInputs == 0
    input_shape = desc.m_InputShape
    assert input_shape.GetNumDimensions() == 0
321
322
def test_slice_descriptor_default_values():
    """SliceDescriptor begin/size accept a list or a tuple and read back as lists.

    Despite its name, this test exercises assignment rather than default values.
    """
    begin = [1, 2, 3, 4, 5]
    size = (1, 2, 3, 4)
    desc = ann.SliceDescriptor()
    desc.m_Begin = begin
    desc.m_Size = size
    assert desc.m_Begin == [1, 2, 3, 4, 5]
    assert desc.m_Size == list(size)
330
331
def test_slice_descriptor_ctor():
    """The (begin, size) constructor stores both; a tuple size reads back as a list."""
    begin, size = [1, 2, 3, 4, 5], (1, 2, 3, 4)
    desc = ann.SliceDescriptor(begin, size)
    assert desc.m_Begin == [1, 2, 3, 4, 5]
    assert desc.m_Size == [1, 2, 3, 4]
337
338
def test_strided_slice_descriptor_default_values():
    """StridedSliceDescriptor begin/end/stride and all masks can be assigned and read back.

    Despite its name, this test exercises assignment rather than default values.
    """
    desc = ann.StridedSliceDescriptor()
    desc.m_Begin = [1, 2, 3, 4, 5]
    desc.m_End = [6, 7, 8, 9, 10]
    desc.m_Stride = (10, 10)
    masks = {
        "m_BeginMask": 1,
        "m_EndMask": 2,
        "m_ShrinkAxisMask": 3,
        "m_EllipsisMask": 4,
        "m_NewAxisMask": 5,
    }
    for mask_name, mask_value in masks.items():
        setattr(desc, mask_name, mask_value)

    assert desc.m_Begin == [1, 2, 3, 4, 5]
    assert desc.m_End == [6, 7, 8, 9, 10]
    assert desc.m_Stride == [10, 10]
    for mask_name, mask_value in masks.items():
        assert getattr(desc, mask_name) == mask_value, mask_name
358
359
def test_strided_slice_descriptor_ctor():
    """The (begin, end, stride) constructor stores all three and leaves every mask at 0.

    The previous version re-assigned m_Begin/m_End/m_Stride after construction,
    so the asserts passed even if the constructor ignored its arguments.
    """
    desc = ann.StridedSliceDescriptor([1, 2, 3, 4, 5], [6, 7, 8, 9, 10], (10, 10))

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    # A tuple stride is read back as a list.
    assert [10, 10] == desc.m_Stride
    assert 0 == desc.m_BeginMask
    assert 0 == desc.m_EndMask
    assert 0 == desc.m_ShrinkAxisMask
    assert 0 == desc.m_EllipsisMask
    assert 0 == desc.m_NewAxisMask
369
370
def test_softmax_descriptor_default_values():
    """A default SoftmaxDescriptor works on axis -1 with beta 1.0."""
    desc = ann.SoftmaxDescriptor()
    assert desc.m_Axis == -1
    # np.allclose only returns a bool; without `assert` this check did nothing.
    assert np.allclose(1.0, desc.m_Beta)
375
376
def test_space_to_batch_nd_descriptor_default_values():
    """A default SpaceToBatchNdDescriptor: block shape [1, 1], zero padding, NCHW layout."""
    desc = ann.SpaceToBatchNdDescriptor()
    zero_pad = (0, 0)
    assert desc.m_BlockShape == [1, 1]
    assert desc.m_PadList == [zero_pad, zero_pad]
    assert desc.m_DataLayout == ann.DataLayout_NCHW
382
383
def test_space_to_batch_nd_descriptor_assigned_values():
    """An assigned block shape and pad list read back as lists; the layout stays NCHW."""
    desc = ann.SpaceToBatchNdDescriptor()
    block_shape, pad_list = (90, 100), [(1, 2), (3, 4)]
    desc.m_BlockShape = block_shape
    desc.m_PadList = pad_list
    assert desc.m_BlockShape == [90, 100]
    assert desc.m_PadList == [(1, 2), (3, 4)]
    assert desc.m_DataLayout == ann.DataLayout_NCHW
391
392
def test_space_to_batch_nd_descriptor_ctor():
    """The (block_shape, pad_list) constructor stores both and keeps NCHW layout."""
    block_shape = (1, 2, 3)
    pad_list = [(1, 2), (3, 4)]
    desc = ann.SpaceToBatchNdDescriptor(block_shape, pad_list)
    assert desc.m_BlockShape == [1, 2, 3]
    assert desc.m_PadList == [(1, 2), (3, 4)]
    assert desc.m_DataLayout == ann.DataLayout_NCHW
398
399
def test_transpose_convolution2d_descriptor_default_values():
    """A default TransposeConvolution2dDescriptor: zero pads/strides, no bias, NCHW, no output shape."""
    desc = ann.TransposeConvolution2dDescriptor()
    expected_defaults = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
        "m_OutputShapeEnabled": False,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(desc, attr_name) == expected, attr_name
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100411
def test_transpose_descriptor_default_values():
    """A TransposeDescriptor exposes the mappings of the PermutationVector it was built from."""
    mappings = (0, 3, 2, 1, 4)
    pv = ann.PermutationVector(mappings)
    desc = ann.TransposeDescriptor(pv)
    assert desc.m_DimMappings.GetSize() == len(mappings)
    for index, expected in enumerate(mappings):
        assert desc.m_DimMappings[index] == expected, index
Richard Burtondc0c6ed2020-04-08 16:39:05 +0100421
def test_view_descriptor_default_values():
    """A default SplitterDescriptor has no views and no dimensions."""
    descriptor = ann.SplitterDescriptor()
    assert (descriptor.GetNumViews(), descriptor.GetNumDimensions()) == (0, 0)
426
427
def test_elementwise_unary_descriptor_default_values():
    """A default ElementwiseUnaryDescriptor performs the Abs operation."""
    descriptor = ann.ElementwiseUnaryDescriptor()
    assert ann.UnaryOperation_Abs == descriptor.m_Operation
431
432
def test_logical_binary_descriptor_default_values():
    """A default LogicalBinaryDescriptor performs the LogicalAnd operation."""
    descriptor = ann.LogicalBinaryDescriptor()
    assert ann.LogicalBinaryOperation_LogicalAnd == descriptor.m_Operation
436
def test_view_descriptor_incorrect_input():
    """Out-of-range indices are rejected by both the origin setter and the size setter."""
    desc = ann.SplitterDescriptor(2, 3)
    bad_view, bad_coord, bad_value = 1000, 100, 1000

    with pytest.raises(RuntimeError) as origin_err:
        desc.SetViewOriginCoord(bad_view, bad_coord, bad_value)
    assert "Failed to set view origin coordinates." in str(origin_err.value)

    with pytest.raises(RuntimeError) as size_err:
        desc.SetViewSize(bad_view, bad_coord, bad_value)
    assert "Failed to set view size." in str(size_err.value)
446
447
def test_view_descriptor_ctor():
    """Origins and sizes set on a SplitterDescriptor(2 views, 3 dims) read back per view.

    NOTE(review): the descriptor has 3 dimensions, yet each per-view list is
    expected to hold 2 entries. This appears to rely on the binding sizing the
    returned list by the number of views rather than dimensions — confirm.
    """
    desc = ann.SplitterDescriptor(2, 3)
    size_base = 1
    origin_base = 5
    dims = range(desc.GetNumDimensions())
    for view in range(desc.GetNumViews()):
        for dim in dims:
            desc.SetViewOriginCoord(view, dim, origin_base + view)
            desc.SetViewSize(view, dim, size_base + view)

    assert desc.GetNumViews() == 2
    assert desc.GetNumDimensions() == 3
    assert desc.GetViewOrigin(0) == [5, 5]
    assert desc.GetViewOrigin(1) == [6, 6]
    assert desc.GetViewSizes(0) == [1, 1]
    assert desc.GetViewSizes(1) == [2, 2]
463
464
def test_createdescriptorforconcatenation_ctor():
    """CreateDescriptorForConcatenation builds one view per input, stacked along the concat axis.

    Inputs of shape (2, 1), (3, 1) and (4, 1) concatenated on axis 0 are expected
    to start at rows 0, 2 and 5 respectively, and at column 0.
    """
    input_shape_vector = [ann.TensorShape((2, 1)), ann.TensorShape((3, 1)), ann.TensorShape((4, 1))]
    desc = ann.CreateDescriptorForConcatenation(input_shape_vector, 0)
    assert 3 == desc.GetNumViews()
    assert 0 == desc.GetConcatAxis()
    assert 2 == desc.GetNumDimensions()
    # The origins used to be fetched into unused locals and never checked.
    # Only the first GetNumDimensions() (= 2) entries are inspected; the length
    # of the list returned by the binding is deliberately not asserted here.
    for view, expected_row in enumerate((0, 2, 5)):
        origin = desc.GetViewOrigin(view)
        assert origin[0] == expected_row, view
        assert origin[1] == 0, view
473
474
def test_createdescriptorforconcatenation_wrong_shape_for_axis():
    """Inputs that differ outside the concatenation axis are rejected with RuntimeError."""
    # Shapes differ along axis 1 (2, 4, 6), so concatenating on axis 0 is invalid.
    input_shape_vector = [ann.TensorShape((1, 2)), ann.TensorShape((3, 4)), ann.TensorShape((5, 6))]
    with pytest.raises(RuntimeError) as err:
        # The return value is irrelevant; only the raised error is checked.
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    assert "All inputs to concatenation must be the same size along all dimensions except the concatenation dimension" in str(
        err.value)
482
483
@pytest.mark.parametrize("input_shape_vector", [[-1, "one"],
                                                [1.33, 4.55],
                                                [{1: "one"}]], ids=str)
def test_createdescriptorforconcatenation_rubbish_assignment_shape_vector(input_shape_vector):
    """A shape vector that does not contain TensorShape objects is rejected with TypeError.

    The message identifies argument 1 and the expected std::vector<TensorShape> type.
    """
    with pytest.raises(TypeError) as err:
        # The return value is irrelevant; only the raised error is checked.
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    assert "in method 'CreateDescriptorForConcatenation', argument 1 of type 'std::vector< armnn::TensorShape,std::allocator< armnn::TensorShape > >'" in str(
        err.value)
493
494
# Every class exposed by the generated bindings module, as (name, class) pairs.
generated_classes = inspect.getmembers(generated, inspect.isclass)
generated_classes_names = [name for name, _ in generated_classes]

# Descriptors that must be exported by the generated module and must compare
# equal when default-constructed. This list and the test class below used to
# be defined twice; the second copy shadowed the first, so they now live here once.
_MASS_CHECKED_DESCRIPTORS = [
    'ActivationDescriptor',
    'ArgMinMaxDescriptor',
    'PermuteDescriptor',
    'SoftmaxDescriptor',
    'ConcatDescriptor',
    'SplitterDescriptor',
    'Pooling2dDescriptor',
    'FullyConnectedDescriptor',
    'Convolution2dDescriptor',
    'Convolution3dDescriptor',
    'DepthwiseConvolution2dDescriptor',
    'DetectionPostProcessDescriptor',
    'NormalizationDescriptor',
    'L2NormalizationDescriptor',
    'BatchNormalizationDescriptor',
    'InstanceNormalizationDescriptor',
    'BatchToSpaceNdDescriptor',
    'FakeQuantizationDescriptor',
    'ResizeDescriptor',
    'ReshapeDescriptor',
    'SpaceToBatchNdDescriptor',
    'SpaceToDepthDescriptor',
    'LstmDescriptor',
    'MeanDescriptor',
    'PadDescriptor',
    'SliceDescriptor',
    'StackDescriptor',
    'StridedSliceDescriptor',
    'TransposeConvolution2dDescriptor',
    'TransposeDescriptor',
    'ElementwiseUnaryDescriptor',
    'FillDescriptor',
    'GatherDescriptor',
    'LogicalBinaryDescriptor',
]


@pytest.mark.parametrize("desc_name", _MASS_CHECKED_DESCRIPTORS)
class TestDescriptorMassChecks:
    """Checks applied uniformly to every descriptor name in _MASS_CHECKED_DESCRIPTORS."""

    def test_desc_implemented(self, desc_name):
        """The descriptor class is exported by the generated module."""
        assert desc_name in generated_classes_names

    def test_desc_equal(self, desc_name):
        """Two default-constructed instances of the descriptor compare equal."""
        desc_class = dict(generated_classes)[desc_name]

        assert desc_class() == desc_class()
587