blob: 8b57169596d98cbd2f06de6bdb65aff659229a44 [file] [log] [blame]
# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
3from copy import copy
4
5import pytest
6import numpy as np
7import pyarmnn as ann
8
9
def __get_tensor_info(dt):
    """Return an ``ann.TensorInfo`` with shape (2, 3) and data type ``dt``."""
    shape = ann.TensorShape((2, 3))
    return ann.TensorInfo(shape, dt)
14
15
@pytest.mark.parametrize(
    "dt",
    [
        ann.DataType_Float32,
        ann.DataType_Float16,
        ann.DataType_QAsymmU8,
        ann.DataType_QSymmS8,
        ann.DataType_QAsymmS8,
    ],
)
def test_create_tensor_with_info(dt):
    """A Tensor built from a TensorInfo reports that info's sizes and data type."""
    info = __get_tensor_info(dt)
    # Capture the reference values from the info before the tensor exists.
    expected_elements = info.GetNumElements()
    expected_bytes = info.GetNumBytes()

    created = ann.Tensor(info)

    # GetInfo() must not compare equal to the TensorInfo that was passed in.
    assert info != created.GetInfo(), "Different objects"
    assert expected_elements == created.GetNumElements()
    assert expected_bytes == created.GetNumBytes()
    assert dt == created.GetDataType()
31
32
def test_create_tensor_undefined_datatype():
    """Building a Tensor from an info whose data type is set to 99 raises ValueError."""
    info = ann.TensorInfo()
    info.SetDataType(99)

    with pytest.raises(ValueError) as raised:
        ann.Tensor(info)

    message = str(raised.value)
    assert 'The data type provided for this Tensor is not supported.' in message
41
42
@pytest.mark.parametrize("dt", [ann.DataType_Float32])
def test_tensor_memory_output(dt):
    """A Float32 tensor exposes a non-empty memory area with 4-byte items."""
    tensor = ann.Tensor(__get_tensor_info(dt))

    # Nothing has written to the tensor yet, so the values are arbitrary;
    # the check is only that the listed contents are non-empty.
    contents = tensor.get_memory_area().tolist()
    assert contents  # has random stuff

    item_size = tensor.get_memory_area().itemsize
    assert 4 == item_size, "it is float32"
51
52
@pytest.mark.parametrize(
    "dt",
    [
        ann.DataType_Float32,
        ann.DataType_Float16,
        ann.DataType_QAsymmU8,
        ann.DataType_QSymmS8,
        ann.DataType_QAsymmS8,
    ],
)
def test_tensor__str__(dt):
    """``str(tensor)`` lists data type, byte count, dimensions and element count."""
    info = __get_tensor_info(dt)
    element_count = info.GetNumElements()
    byte_count = info.GetNumBytes()
    dimension_count = info.GetNumDimensions()

    tensor = ann.Tensor(info)
    rendered = str(tensor)

    # Doubled braces are literal '{' / '}' in the expected text.
    template = ("Tensor{{DataType: {}, NumBytes: {}, NumDimensions: "
                "{}, NumElements: {}}}")
    expected = template.format(dt, byte_count, dimension_count, element_count)

    assert rendered == expected
67
68
def test_create_empty_tensor():
    """A default-constructed Tensor has no elements, no bytes and no memory area."""
    empty = ann.Tensor()

    element_count = empty.GetNumElements()
    assert 0 == element_count

    byte_count = empty.GetNumBytes()
    assert 0 == byte_count

    assert empty.get_memory_area() is None
75
76
@pytest.mark.parametrize(
    "dt",
    [
        ann.DataType_Float32,
        ann.DataType_Float16,
        ann.DataType_QAsymmU8,
        ann.DataType_QSymmS8,
        ann.DataType_QAsymmS8,
    ],
)
def test_create_tensor_from_tensor(dt):
    """A Tensor constructed from another Tensor shares its memory but not its identity."""
    source = ann.Tensor(__get_tensor_info(dt))
    duplicate = ann.Tensor(source)

    # The wrapper and its info must not compare equal to the source's.
    assert duplicate != source, "Different objects"
    assert duplicate.GetInfo() != source.GetInfo(), "Different objects"

    # Both wrappers must point at the same underlying buffer address.
    duplicate_address = duplicate.get_memory_area().ctypes.data
    source_address = source.get_memory_area().ctypes.data
    assert duplicate_address == source_address, "Same memory area"

    # Size and type metadata carry over unchanged.
    assert duplicate.GetNumElements() == source.GetNumElements()
    assert duplicate.GetNumBytes() == source.GetNumBytes()
    assert duplicate.GetDataType() == source.GetDataType()
91
92
@pytest.mark.parametrize(
    "dt",
    [
        ann.DataType_Float32,
        ann.DataType_Float16,
        ann.DataType_QAsymmU8,
        ann.DataType_QSymmS8,
        ann.DataType_QAsymmS8,
    ],
)
def test_copy_tensor(dt):
    """``copy.copy`` of a Tensor shares its memory but not its identity."""
    info = __get_tensor_info(dt)
    source = ann.Tensor(info)
    duplicate = copy(source)

    # The wrapper and its info must not compare equal to the source's.
    assert duplicate != source, "Different objects"
    assert duplicate.GetInfo() != source.GetInfo(), "Different objects"

    # Both wrappers must point at the same underlying buffer address.
    duplicate_address = duplicate.get_memory_area().ctypes.data
    source_address = source.get_memory_area().ctypes.data
    assert duplicate_address == source_address, "Same memory area"

    # Size and type metadata carry over unchanged.
    assert duplicate.GetNumElements() == source.GetNumElements()
    assert duplicate.GetNumBytes() == source.GetNumBytes()
    assert duplicate.GetDataType() == source.GetDataType()
106
107
@pytest.mark.parametrize(
    "dt",
    [
        ann.DataType_Float32,
        ann.DataType_Float16,
        ann.DataType_QAsymmU8,
        ann.DataType_QSymmS8,
        ann.DataType_QAsymmS8,
    ],
)
def test_copied_tensor_has_memory_area_access_after_deletion_of_original_tensor(dt):
    """A copied Tensor can still read the memory area after the original name is deleted."""
    info = __get_tensor_info(dt)
    tensor = ann.Tensor(info)

    # Write a known marker value into the first slot of the buffer.
    tensor.get_memory_area()[0] = 100

    # np.array() takes an independent copy to compare against later.
    snapshot = np.array(tensor.get_memory_area())
    assert 100 == snapshot[0]

    copied_tensor = ann.Tensor(tensor)

    # Drop the original reference; the copy must still see the same data.
    del tensor
    np.testing.assert_array_equal(copied_tensor.get_memory_area(), snapshot)
    assert 100 == copied_tensor.get_memory_area()[0]
126
127
def test_create_const_tensor_incorrect_args():
    """Passing two strings to the Tensor constructor raises ValueError with a fixed message."""
    with pytest.raises(ValueError) as raised:
        ann.Tensor('something', 'something')

    message = str(raised.value)
    assert "Incorrect number of arguments or type of arguments provided to create Tensor." in message
134
135
@pytest.mark.parametrize("dt", [ann.DataType_Float16])
def test_tensor_memory_output_fp16(dt):
    """A Float16 tensor from the shared helper reports 6 elements, 12 bytes and Float16."""
    tensor = ann.Tensor(__get_tensor_info(dt))

    element_count = tensor.GetNumElements()
    byte_count = tensor.GetNumBytes()
    data_type = tensor.GetDataType()

    assert element_count == 6
    assert byte_count == 12
    assert data_type == ann.DataType_Float16
144 assert tensor.GetDataType() == ann.DataType_Float16