blob: 892b8e4f0849c14af20b6dab1051231804c0113e [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5%{
6#include "armnn/Tensor.hpp"
7%}
8
9%include <typemaps/tensor_memory.i>
10%include <typemaps/tensor_shape.i>
11
12namespace armnn
13{
14
%feature("docstring",
"
Class for holding the shape information of an Arm NN tensor.

This class is iterable. You can iterate over it to get each value of the Tensor shape.

Examples:
    Obtain tensor shape information as a list.
    >>> import pyarmnn as ann
    >>> import numpy as np
    >>>
    >>> tensor_info = ann.TensorInfo(ann.TensorShape((4, 2, 1, 3)), ann.DataType_Float32, 0.0, 0, True)
    >>> tensor = ann.ConstTensor(tensor_info, np.ones([4, 2, 1, 3], dtype=np.float32))
    >>> print(list(tensor.GetShape()))
    [4, 2, 1, 3]

") TensorShape;
class TensorShape
{
    // Make TensorShape iterable so we can return shape dims easily.
    // Iteration walks indices 0..GetNumDimensions()-1 and delegates to __getitem__.
    %pythoncode %{
    def __iter__(self):
        for dim in range(self.GetNumDimensions()):
            yield self[dim]
    %}


public:
    // The tensor_shape typemap (typemaps/tensor_shape.i) lets Python callers pass the
    // dimensions as a single argument, e.g. TensorShape((4, 2, 1, 3)) as in the class
    // docstring example, instead of an explicit (count, pointer) pair. The typemap is
    // cleared straight after this declaration so it only applies to this constructor.
    %tensor_shape_typemap(unsigned int numDimensions, const unsigned int* dimensionSizes);
    TensorShape(unsigned int numDimensions, const unsigned int* dimensionSizes);
    %clear_tensor_shape_typemap(unsigned int numDimensions, const unsigned int* dimensionSizes);

    %feature("docstring",
    "
    Returns the number of dimensions in this TensorShape.

    Returns:
        int: The number of dimensions in this TensorShape.

    ") GetNumDimensions;
    unsigned int GetNumDimensions() const;

    %feature("docstring",
    "
    Returns the total number of elements for a tensor with this TensorShape.

    Returns:
        int: The total number of elements for a tensor with this TensorShape.

    ") GetNumElements;
    unsigned int GetNumElements() const;

};
68
%extend TensorShape {

    // Python indexing: shape[i]. The index is forwarded unchanged to
    // TensorShape::operator[]; any range checking is whatever that operator performs.
    unsigned int __getitem__(unsigned int i) const {
        return (*$self)[i];
    }

    // Python item assignment: shape[i] = val, writing through the non-const operator[].
    void __setitem__(unsigned int i, unsigned int val) {
        (*$self)[i] = val;
    }

    // Human-readable form, e.g.
    //   TensorShape{Shape(4, 2, 1, 3), NumDimensions: 4, NumElements: 24}
    std::string __str__() {
        // Query the summary values up front, before walking the dimensions.
        const unsigned int rank = $self->GetNumDimensions();
        const unsigned int elements = $self->GetNumElements();
        const std::string dimsPart = "NumDimensions: " + std::to_string(rank);
        const std::string elemsPart = "NumElements: " + std::to_string(elements);

        std::string text = "TensorShape{Shape(";
        for (unsigned int idx = 0; idx < rank; ++idx) {
            // Separator goes before every entry except the first.
            if (idx > 0) {
                text += ", ";
            }
            text += std::to_string((*$self)[idx]);
        }

        text += "), ";
        text += dimsPart;
        text += ", ";
        text += elemsPart;
        text += "}";
        return text;
    }

}
99
100
%feature("docstring",
"
Class for holding the tensor information of an Arm NN tensor such as quantization, datatype, shape etc.

Can be default-constructed, copy-constructed from another TensorInfo, or built from:

Args:
    shape (TensorShape): Shape of the tensor.
    dataType (DataType): Data type of the tensor elements.
    quantizationScale (float): Quantization scale. Defaults to 0.0.
    quantizationOffset (int): Quantization offset. Defaults to 0.
    isConstant (bool): Whether the tensor info is constant. Defaults to False.

") TensorInfo;
class TensorInfo
{
public:
    TensorInfo();

    TensorInfo(const TensorInfo& other);

    // Default values must be C++ literals (false/true, not Python's False/True): SWIG can
    // emit them verbatim into the generated C++ wrapper (compactdefaultargs / kwargs), and
    // it translates C++ true/false to Python True/False wherever it documents defaults.
    TensorInfo(const TensorShape& shape, DataType dataType,
               float quantizationScale = 0.0f, int32_t quantizationOffset = 0,
               bool isConstant = false);

    %feature("docstring",
    "
    Get the tensor shape.

    Returns:
        TensorShape: Current shape of the tensor.

    ") GetShape;
    TensorShape& GetShape();

    %feature("docstring",
    "
    Set the tensor shape. The new shape is expected to have the same number of elements as the current tensor.

    Args:
        newShape (TensorShape): New tensor shape to reshape to.

    ") SetShape;
    void SetShape(const TensorShape& newShape);

    %feature("docstring",
    "
    Returns the number of dimensions in this Tensor.

    Returns:
        int: The number of dimensions in this Tensor.

    ") GetNumDimensions;
    unsigned int GetNumDimensions() const;

    %feature("docstring",
    "
    Returns the total number of elements for this Tensor.

    Returns:
        int: The total number of elements for this Tensor.

    ") GetNumElements;
    unsigned int GetNumElements() const;

    %feature("docstring",
    "
    Get the tensor datatype.

    Returns:
        DataType: Current tensor DataType.

    ") GetDataType;
    DataType GetDataType() const;

    %feature("docstring",
    "
    Set the tensor datatype.

    Args:
        type (DataType): DataType to set the tensor to.

    ") SetDataType;
    void SetDataType(DataType type);

    %feature("docstring",
    "
    Get the value of the tensor's quantization scale.

    Returns:
        float: Tensor quantization scale value.

    ") GetQuantizationScale;
    float GetQuantizationScale() const;

    %feature("docstring",
    "
    Get the value of the tensor's quantization offset.

    Returns:
        int: Tensor quantization offset value.

    ") GetQuantizationOffset;
    int32_t GetQuantizationOffset() const;

    %feature("docstring",
    "
    Set the value of the tensor's quantization scale.

    Args:
        scale (float): Scale value to set.

    ") SetQuantizationScale;
    void SetQuantizationScale(float scale);

    %feature("docstring",
    "
    Set the value of the tensor's quantization offset.

    Args:
        offset (int): Offset value to set.

    ") SetQuantizationOffset;
    void SetQuantizationOffset(int32_t offset);

    %feature("docstring",
    "
    Returns true if the tensor is a quantized data type.

    Returns:
        bool: True if the tensor is a quantized data type.

    ") IsQuantized;
    bool IsQuantized() const;

    %feature("docstring",
    "
    Returns true if the tensor info is constant.

    Returns:
        bool: True if the tensor info is constant.

    ") IsConstant;
    bool IsConstant() const;

    %feature("docstring",
    "
    Sets whether the tensor info is constant.

    Args:
        IsConstant (bool): True to mark the tensor info as constant, False otherwise. Defaults to True.

    ") SetConstant;
    // The parameter keeps the name used by the C++ header (it would also be the Python
    // keyword name if keyword arguments are enabled), so it is deliberately not renamed.
    void SetConstant(const bool IsConstant = true);

    %feature("docstring",
    "
    Check that the types are the same and, if quantized, that the quantization parameters are the same.

    Args:
        other (TensorInfo): The TensorInfo to compare against.

    Returns:
        bool: True if matched, else False.

    ") IsTypeSpaceMatch;
    bool IsTypeSpaceMatch(const TensorInfo& other) const;

    %feature("docstring",
    "
    Get the number of bytes needed for this tensor.

    Returns:
        int: Number of bytes consumed by this tensor.

    ") GetNumBytes;
    unsigned int GetNumBytes() const;

};
270
%extend TensorInfo {

    // Human-readable summary. DataType is printed as its underlying integer value,
    // booleans print as 1/0 (std::to_string on bool) and the scale uses
    // std::to_string's fixed six-decimal float formatting.
    std::string __str__() {
        std::string text = "TensorInfo{DataType: ";
        text += std::to_string(static_cast<int>($self->GetDataType()));

        text += ", IsQuantized: ";
        text += std::to_string($self->IsQuantized());

        text += ", QuantizationScale: ";
        text += std::to_string($self->GetQuantizationScale());

        text += ", QuantizationOffset: ";
        text += std::to_string($self->GetQuantizationOffset());

        text += ", IsConstant: ";
        text += std::to_string($self->IsConstant());

        text += ", NumDimensions: ";
        text += std::to_string($self->GetNumDimensions());

        text += ", NumElements: ";
        text += std::to_string($self->GetNumElements());

        text += "}";
        return text;
    }

}
285
// Mutable Arm NN tensor: constructed from a TensorInfo plus a writable memory area.
class Tensor
{
public:
    // ---- Construction / destruction ----
    Tensor();
    Tensor(const Tensor& other);

    // `memory` is converted by the mutable_memory typemap from typemaps/tensor_memory.i;
    // the paired %clear_ macro presumably restricts that typemap to this one constructor,
    // so the apply/declare/clear order below must stay as it is.
    %mutable_memory(void* memory);
    Tensor(const TensorInfo& info, void* memory);
    %clear_mutable_memory(void* memory);

    ~Tensor();

    // ---- Read-only metadata accessors ----
    const TensorInfo& GetInfo() const;
    const TensorShape& GetShape() const;

    DataType GetDataType() const;
    unsigned int GetNumDimensions() const;
    unsigned int GetNumBytes() const;
    unsigned int GetNumElements() const;

    // Intentionally NOT wrapped: Python callers must go through get_memory_area() in the
    // public API rather than obtaining the raw pointer here.
    // void* GetMemoryArea() const;
};
308
%extend Tensor {

    // Human-readable summary; DataType is printed as its underlying integer value.
    std::string __str__() {
        std::string text = "Tensor{DataType: ";
        text += std::to_string(static_cast<int>($self->GetDataType()));

        text += ", NumBytes: ";
        text += std::to_string($self->GetNumBytes());

        text += ", NumDimensions: ";
        text += std::to_string($self->GetNumDimensions());

        text += ", NumElements: ";
        text += std::to_string($self->GetNumElements());

        text += "}";
        return text;
    }
}
319
// Read-only Arm NN tensor: constructed from a TensorInfo plus a const memory area,
// or from an existing Tensor / ConstTensor.
class ConstTensor
{
public:
    // ---- Construction / destruction ----
    ConstTensor();
    ConstTensor(const Tensor& other);
    ConstTensor(const ConstTensor& other);

    // `memory` is converted by the const_memory typemap from typemaps/tensor_memory.i;
    // the paired %clear_ macro presumably restricts that typemap to this one constructor,
    // so the apply/declare/clear order below must stay as it is.
    %const_memory(const void* memory);
    ConstTensor(const TensorInfo& info, const void* memory);
    %clear_const_memory(const void* memory);

    ~ConstTensor();

    // ---- Read-only metadata accessors ----
    const TensorInfo& GetInfo() const;
    const TensorShape& GetShape() const;

    DataType GetDataType() const;
    unsigned int GetNumDimensions() const;
    unsigned int GetNumBytes() const;
    unsigned int GetNumElements() const;

    // Intentionally NOT wrapped: Python callers must go through get_memory_area() in the
    // public API rather than obtaining the raw pointer here.
    // void* GetMemoryArea() const;
};
343
%extend ConstTensor {

    // Human-readable summary; DataType is printed as its underlying integer value.
    std::string __str__() {
        std::string text = "ConstTensor{DataType: ";
        text += std::to_string(static_cast<int>($self->GetDataType()));

        text += ", NumBytes: ";
        text += std::to_string($self->GetNumBytes());

        text += ", NumDimensions: ";
        text += std::to_string($self->GetNumDimensions());

        text += ", NumElements: ";
        text += std::to_string($self->GetNumElements());

        text += "}";
        return text;
    }
}
354
355}