blob: 0d15be5e73ff0a66de799133d31bfc6cb7383eca [file] [log] [blame]
Richard Burtondc0c6ed2020-04-08 16:39:05 +01001# Copyright © 2020 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3import pytest
4
5import pyarmnn as ann
6
7
@pytest.fixture(scope="function")
def network():
    """Provide a freshly constructed ``ann.INetwork``.

    The fixture is function-scoped, so every test receives its own instance
    and layers added by one test are never visible to another.
    """
    fresh_network = ann.INetwork()
    return fresh_network
11
12
class TestIInputIOutputIConnectable:
    """Tests for pyarmnn input-slot, output-slot and connectable-layer bindings."""

    @staticmethod
    def _build_addition_graph(network, *, connect=True):
        """Add ``input1``, ``input2``, ``addition`` and ``output`` layers to *network*.

        When *connect* is true the layers are wired as
        ``input1 -> add.InputSlot0``, ``input2 -> add.InputSlot1`` and
        ``add.OutputSlot0 -> output.InputSlot0``.

        Returns:
            tuple: ``(input1, input2, add, output)`` layer handles.
        """
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        if connect:
            input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
            input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
            add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        return input1, input2, add, output

    def test_input_slot(self, network):
        """IInputSlot.GetConnection() returns an IOutputSlot and survives wrapper deletion."""
        # Keep every layer handle bound for the whole test, even the unused ones.
        _input1, _input2, add, _output = self._build_addition_graph(network)

        # Check IInputSlot GetConnection()
        input_slot = add.GetInputSlot(0)
        input_slot_connection = input_slot.GetConnection()
        assert isinstance(input_slot_connection, ann.IOutputSlot)

        # Dropping the Python wrapper of the connection must not break the link:
        # a fresh GetConnection() call still yields an IOutputSlot.
        del input_slot_connection
        assert input_slot.GetConnection()
        assert isinstance(input_slot.GetConnection(), ann.IOutputSlot)

        # Likewise, dropping the input-slot wrapper must leave the layer's slot reachable.
        del input_slot
        assert add.GetInputSlot(0)

    def test_output_slot(self, network):
        """Exercise IOutputSlot connection queries, tensor info and Disconnect()."""
        _input1, _input2, add, output = self._build_addition_graph(network)

        # Check IInputSlot GetConnection(): each call returns the IOutputSlot feeding that input.
        input1_output_slot = add.GetInputSlot(0).GetConnection()  # owned by input1
        add_output_slot = output.GetInputSlot(0).GetConnection()  # owned by add

        # Check IOutputSlot GetConnection(): connection 0 of add's output slot is
        # output.InputSlot0, whose own GetConnection() is an IOutputSlot again.
        output_input_slot = add.GetOutputSlot(0).GetConnection(0)
        assert isinstance(output_input_slot.GetConnection(), ann.IOutputSlot)

        # Test IOutputSlot GetNumConnections(), len(), indexing & CalculateIndexOnOwner().
        # input1's output slot has exactly one consumer (add.InputSlot0) and is slot 0 of input1.
        assert input1_output_slot.GetNumConnections() == 1
        assert len(input1_output_slot) == 1
        assert input1_output_slot[0]
        assert input1_output_slot.CalculateIndexOnOwner() == 0

        # Check GetOwningLayerGuid(): the two output slots belong to different layers
        # (input1 and add), so their owning-layer GUIDs must differ.
        assert input1_output_slot.GetOwningLayerGuid() != add_output_slot.GetOwningLayerGuid()

        # Set TensorInfo
        test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)

        # Check IsTensorInfoSet()
        assert not input1_output_slot.IsTensorInfoSet()
        input1_output_slot.SetTensorInfo(test_tensor_info)
        assert input1_output_slot.IsTensorInfoSet()

        # Check GetTensorInfo(): the (2, 3) shape set above has 2 dimensions and 6 elements.
        output_tensor_info = input1_output_slot.GetTensorInfo()
        assert output_tensor_info.GetNumDimensions() == 2
        assert output_tensor_info.GetNumElements() == 6

        # Check Disconnect(): add.OutputSlot0 starts with its single connection to
        # output.InputSlot0 and has none once that link is removed.
        assert add_output_slot.GetNumConnections() == 1
        add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0))
        assert add_output_slot.GetNumConnections() == 0

    def test_output_slot__out_of_range(self, network):
        """Indexing an IOutputSlot past its connections raises ValueError."""
        # A lone input layer: its output slot is not connected to anything.
        input1 = network.AddInputLayer(0, "input1")
        output_slot = input1.GetOutputSlot(0)

        with pytest.raises(ValueError) as err:
            output_slot[1]

        assert "Invalid index 1 provided" in str(err.value)

    def test_iconnectable_guid(self, network):
        """Distinct layers report distinct GUIDs."""
        # Check IConnectable GetGuid()
        # Note Guid can change based on which tests are run so
        # checking here that each layer does not have the same guid
        add_id = network.AddAdditionLayer().GetGuid()
        output_id = network.AddOutputLayer(0).GetGuid()
        assert add_id != output_id

    def test_iconnectable_layer_functions(self, network):
        """Check slot counts, names and slot accessors on unconnected layers."""
        input1, input2, add, output = self._build_addition_graph(network, connect=False)

        # Check GetNumInputSlots(), GetName() & GetNumOutputSlots()
        assert input1.GetNumInputSlots() == 0
        assert input1.GetName() == "input1"
        assert input1.GetNumOutputSlots() == 1

        assert input2.GetNumInputSlots() == 0
        assert input2.GetName() == "input2"
        assert input2.GetNumOutputSlots() == 1

        assert add.GetNumInputSlots() == 2
        assert add.GetName() == "addition"
        assert add.GetNumOutputSlots() == 1

        assert output.GetNumInputSlots() == 1
        assert output.GetName() == "output"
        assert output.GetNumOutputSlots() == 0

        # Check GetOutputSlot(): nothing has been connected yet.
        input1_get_output = input1.GetOutputSlot(0)
        assert input1_get_output.GetNumConnections() == 0
        assert len(input1_get_output) == 0

        # Check GetInputSlot()
        add_get_input = add.GetInputSlot(0)
        # NOTE(review): the result is not checked; this appears to be a smoke call
        # verifying GetConnection() can be invoked on an unconnected input slot.
        add_get_input.GetConnection()
        assert isinstance(add_get_input, ann.IInputSlot)