# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
| 3 | import pytest |
| 4 | |
| 5 | import pyarmnn as ann |
| 6 | |
| 7 | |
@pytest.fixture(scope="function")
def network():
    """Provide a freshly constructed ``ann.INetwork``.

    Function scope means every test receives its own network instance, so
    layers added by one test never leak into another.
    """
    fresh_network = ann.INetwork()
    return fresh_network
| 11 | |
| 12 | |
class TestIInputIOutputIConnectable:
    """Tests for the IInputSlot, IOutputSlot and IConnectableLayer bindings.

    Most tests build the graph ``input1``/``input2`` -> ``addition`` -> ``output``
    on the function-scoped ``network`` fixture and then query its slots.
    """

    def test_input_slot(self, network):
        """IInputSlot.GetConnection() yields an IOutputSlot that stays queryable after ``del``."""
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        # Check IInputSlot GetConnection()
        input_slot = add.GetInputSlot(0)
        input_slot_connection = input_slot.GetConnection()

        assert isinstance(input_slot_connection, ann.IOutputSlot)

        # Dropping the Python reference to the connection must not break the
        # slot: GetConnection() still returns a valid IOutputSlot afterwards.
        del input_slot_connection

        assert input_slot.GetConnection()
        assert isinstance(input_slot.GetConnection(), ann.IOutputSlot)

        # Likewise, dropping the input slot reference must leave the layer's slot accessible.
        del input_slot

        assert add.GetInputSlot(0)

    def test_output_slot(self, network):
        """Exercise IOutputSlot: connections, indexing, owner GUID, TensorInfo and Disconnect()."""
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        # Check IInputSlot GetConnection().
        # Note the names: add_get_input_connection is input1's output slot
        # (it feeds add.InputSlot0), and output_get_input_connection is add's
        # output slot (it feeds output.InputSlot0).
        add_get_input_connection = add.GetInputSlot(0).GetConnection()
        output_get_input_connection = output.GetInputSlot(0).GetConnection()

        # Check IOutputSlot GetConnection(): add.OutputSlot0's first connection is
        # output.InputSlot0, whose own GetConnection() leads back to an IOutputSlot.
        add_get_output_connect = add.GetOutputSlot(0).GetConnection(0)
        assert isinstance(add_get_output_connect.GetConnection(), ann.IOutputSlot)

        # Test IOutputSlot GetNumConnections() & CalculateIndexOnOwner().
        # input1's output slot has exactly one connection (to add.InputSlot0).
        assert add_get_input_connection.GetNumConnections() == 1
        assert len(add_get_input_connection) == 1
        assert add_get_input_connection[0]
        assert add_get_input_connection.CalculateIndexOnOwner() == 0

        # Check GetOwningLayerGuid(). The two slots belong to different layers
        # (input1 and add), so their owning-layer GUIDs must differ.
        assert (add_get_input_connection.GetOwningLayerGuid()
                != output_get_input_connection.GetOwningLayerGuid())

        # Set TensorInfo
        test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)

        # Check IsTensorInfoSet(): false until SetTensorInfo() is called on input1's output slot
        assert not add_get_input_connection.IsTensorInfoSet()
        add_get_input_connection.SetTensorInfo(test_tensor_info)
        assert add_get_input_connection.IsTensorInfoSet()

        # Check GetTensorInfo(): shape (2, 3) has 2 dimensions and 6 elements
        output_tensor_info = add_get_input_connection.GetTensorInfo()
        assert 2 == output_tensor_info.GetNumDimensions()
        assert 6 == output_tensor_info.GetNumElements()

        # Check Disconnect(): add.OutputSlot0 starts with 1 connection (to output.InputSlot0)
        assert output_get_input_connection.GetNumConnections() == 1
        add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0))  # disconnect add.OutputSlot0 from output.InputSlot0
        assert output_get_input_connection.GetNumConnections() == 0

    def test_output_slot__out_of_range(self, network):
        """Indexing an IOutputSlot past its connections raises ValueError."""
        # Create input layer to check output slot get item handling.
        # Nothing is connected, so the slot has no connections and index 1 is invalid.
        input1 = network.AddInputLayer(0, "input1")

        output_slot = input1.GetOutputSlot(0)
        with pytest.raises(ValueError) as err:
            output_slot[1]

        assert "Invalid index 1 provided" in str(err.value)

    def test_iconnectable_guid(self, network):
        """Distinct layers report distinct GUIDs."""
        # Check IConnectable GetGuid()
        # Note Guid can change based on which tests are run so
        # checking here that each layer does not have the same guid
        add_id = network.AddAdditionLayer().GetGuid()
        output_id = network.AddOutputLayer(0).GetGuid()
        assert add_id != output_id

    def test_iconnectable_layer_functions(self, network):
        """Check slot counts, names and slot accessors on unconnected layers."""
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Check GetNumInputSlots(), GetName() & GetNumOutputSlots()
        assert input1.GetNumInputSlots() == 0
        assert input1.GetName() == "input1"
        assert input1.GetNumOutputSlots() == 1

        assert input2.GetNumInputSlots() == 0
        assert input2.GetName() == "input2"
        assert input2.GetNumOutputSlots() == 1

        assert add.GetNumInputSlots() == 2
        assert add.GetName() == "addition"
        assert add.GetNumOutputSlots() == 1

        assert output.GetNumInputSlots() == 1
        assert output.GetName() == "output"
        assert output.GetNumOutputSlots() == 0

        # Check GetOutputSlot(): no layers are connected in this test
        input1_get_output = input1.GetOutputSlot(0)
        assert input1_get_output.GetNumConnections() == 0
        assert len(input1_get_output) == 0

        # Check GetInputSlot()
        add_get_input = add.GetInputSlot(0)
        # add.InputSlot0 has no incoming connection here; this call only checks that
        # GetConnection() does not raise - its return value is deliberately not checked.
        add_get_input.GetConnection()
        assert isinstance(add_get_input, ann.IInputSlot)