blob: aeb7ab5fdd432ad37c535bb174a93a824a9a014a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +00005
telsoa014fcda012018-03-09 14:13:49 +00006#include "armnn/Tensor.hpp"
7#include "armnn/Utils.hpp"
8#include "armnn/Exceptions.hpp"
9#include "armnn/TypesUtils.hpp"
10
11#include <boost/assert.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012#include <boost/numeric/conversion/cast.hpp>
13
Aron Virginas-Tar06e25c42019-02-21 15:45:03 +000014#include <sstream>
15
telsoa014fcda012018-03-09 14:13:49 +000016namespace armnn
17{
18
// ---
// --- TensorShape
// ---
22
23TensorShape::TensorShape()
24 : m_NumDimensions(0)
25{
26}
27
Matteo Martincighf9afc792018-12-06 12:03:17 +000028TensorShape::TensorShape(unsigned int numDimensions)
29 : m_NumDimensions(numDimensions)
30{
31 if (numDimensions < 1)
32 {
33 throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
34 }
35
36 if (numDimensions > MaxNumOfTensorDimensions)
37 {
38 throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
39 }
40
41 std::fill(m_Dimensions.begin(), m_Dimensions.begin() + m_NumDimensions, 0);
42}
43
telsoa014fcda012018-03-09 14:13:49 +000044TensorShape::TensorShape(const unsigned int numDimensions, const unsigned int* const dimensionSizes)
45 : m_NumDimensions(numDimensions)
46{
47 if (numDimensions < 1)
48 {
49 throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
50 }
51
52 if (numDimensions > MaxNumOfTensorDimensions)
53 {
54 throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
55 }
56
57 if (dimensionSizes == nullptr)
58 {
59 throw InvalidArgumentException("Tensor dimensionSizes must not be NULL");
60 }
61
62 std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
63}
64
65TensorShape::TensorShape(std::initializer_list<unsigned int> dimensionSizeList)
66 : TensorShape(boost::numeric_cast<unsigned int>(dimensionSizeList.size()), dimensionSizeList.begin())
67{
68}
69
70TensorShape::TensorShape(const TensorShape& other)
71 : m_NumDimensions(other.m_NumDimensions)
72{
73 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
74}
75
76TensorShape& TensorShape::operator =(const TensorShape& other)
77{
78 m_NumDimensions = other.m_NumDimensions;
79 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
80 return *this;
81}
82
Aron Virginas-Tar06e25c42019-02-21 15:45:03 +000083unsigned int TensorShape::operator[](unsigned int i) const
84{
85 CheckDimensionIndex(i);
86 return m_Dimensions.at(i);
87}
88
89unsigned int& TensorShape::operator[](unsigned int i)
90{
91 CheckDimensionIndex(i);
92 return m_Dimensions.at(i);
93}
94
telsoa014fcda012018-03-09 14:13:49 +000095bool TensorShape::operator==(const TensorShape& other) const
96{
97 return ((m_NumDimensions == other.m_NumDimensions) &&
98 std::equal(m_Dimensions.cbegin(), m_Dimensions.cbegin() + m_NumDimensions, other.m_Dimensions.cbegin()));
99}
100
101bool TensorShape::operator!=(const TensorShape& other) const
102{
103 return !(*this == other);
104}
105
106unsigned int TensorShape::GetNumElements() const
107{
108 if (m_NumDimensions == 0)
109 {
110 return 0;
111 }
112
113 unsigned int count = 1;
114 for (unsigned int i = 0; i < m_NumDimensions; i++)
115 {
116 count *= m_Dimensions[i];
117 }
118
119 return count;
120}
121
Aron Virginas-Tar06e25c42019-02-21 15:45:03 +0000122void TensorShape::CheckDimensionIndex(unsigned int i) const
123{
124 if (i >= m_NumDimensions)
125 {
126 std::stringstream errorMessage;
127 errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")";
128 throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION());
129 }
130}
131
// ---
// --- TensorInfo
// ---
135
136TensorInfo::TensorInfo()
137: m_DataType(DataType::Float32)
138{
139}
140
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000141TensorInfo::TensorInfo(const TensorShape& shape,
142 DataType dataType,
143 float quantizationScale,
144 int32_t quantizationOffset)
145 : m_Shape(shape)
146 , m_DataType(dataType)
telsoa014fcda012018-03-09 14:13:49 +0000147{
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000148 SetQuantizationScale(quantizationScale);
149 SetQuantizationOffset(quantizationOffset);
telsoa014fcda012018-03-09 14:13:49 +0000150}
151
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000152TensorInfo::TensorInfo(unsigned int numDimensions,
153 const unsigned int* dimensionSizes,
154 DataType dataType,
155 float quantizationScale,
156 int32_t quantizationOffset)
157 : m_Shape(numDimensions, dimensionSizes)
158 , m_DataType(dataType)
telsoa014fcda012018-03-09 14:13:49 +0000159{
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000160 SetQuantizationScale(quantizationScale);
161 SetQuantizationOffset(quantizationOffset);
162}
163
164TensorInfo::TensorInfo(const TensorShape& shape,
165 DataType dataType,
166 const std::vector<float>& quantizationScales,
167 unsigned int quantizationDim)
168 : m_Shape(shape)
169 , m_DataType(dataType)
170{
171 SetQuantizationScales(quantizationScales);
172 SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
173}
174
175TensorInfo::TensorInfo(unsigned int numDimensions,
176 const unsigned int* dimensionSizes,
177 DataType dataType,
178 const std::vector<float>& quantizationScales,
179 unsigned int quantizationDim)
180 : m_Shape(numDimensions, dimensionSizes)
181 , m_DataType(dataType)
182{
183 SetQuantizationScales(quantizationScales);
184 SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
telsoa014fcda012018-03-09 14:13:49 +0000185}
186
187TensorInfo::TensorInfo(const TensorInfo& other)
188: m_Shape(other.m_Shape)
189, m_DataType(other.m_DataType)
190, m_Quantization(other.m_Quantization)
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000191{}
telsoa014fcda012018-03-09 14:13:49 +0000192
193TensorInfo& TensorInfo::operator=(const TensorInfo& other)
194{
195 m_Shape = other.m_Shape;
196 m_DataType = other.m_DataType;
197 m_Quantization = other.m_Quantization;
198 return *this;
199}
200
201bool TensorInfo::operator==(const TensorInfo& other) const
202{
203 return ((m_Shape == other.m_Shape) &&
204 (m_DataType == other.m_DataType) &&
205 (m_Quantization == other.m_Quantization));
206}
207
208bool TensorInfo::operator!=(const TensorInfo& other) const
209{
210 return !(*this == other);
211}
212
213unsigned int TensorInfo::GetNumBytes() const
214{
215 return GetDataTypeSize(m_DataType) * GetNumElements();
216}
217
Derek Lamberti0790dce2019-04-15 18:37:35 +0100218bool TensorInfo::IsTypeSpaceMatch(const TensorInfo& other) const
219{
220 bool match = true;
221
222 match &= m_DataType == other.m_DataType;
223
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000224 if (IsQuantized() && !HasMultipleQuantizationScales())
Derek Lamberti0790dce2019-04-15 18:37:35 +0100225 {
226 match &= GetQuantizationScale() == other.GetQuantizationScale() &&
227 GetQuantizationOffset() == other.GetQuantizationOffset();
228 }
229 return match;
230}
231
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000232bool TensorInfo::HasPerAxisQuantization() const
233{
234 return HasMultipleQuantizationScales() || m_Quantization.m_QuantizationDim.has_value();
235}
236
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000237std::vector<float> TensorInfo::GetQuantizationScales() const
238{
239 return m_Quantization.m_Scales;
240}
241
242void TensorInfo::SetQuantizationScales(const std::vector<float>& scales)
243{
244 m_Quantization.m_Scales = scales;
245}
246
247float TensorInfo::GetQuantizationScale() const
248{
249 if (m_Quantization.m_Scales.empty())
250 {
251 // NOTE: old default for backward compatibility
252 return 1.0f;
253 }
254
255 BOOST_ASSERT(!HasMultipleQuantizationScales());
256 return m_Quantization.m_Scales[0];
257}
258
259void TensorInfo::SetQuantizationScale(float scale)
260{
261 m_Quantization.m_Scales = { scale };
262}
263
264int32_t TensorInfo::GetQuantizationOffset() const
265{
266 if (!m_Quantization.m_Offset.has_value())
267 {
268 // NOTE: old default for backward compatibility
269 return 0;
270 }
271
272 return m_Quantization.m_Offset.value();
273}
274
275void TensorInfo::SetQuantizationOffset(int32_t offset)
276{
277 m_Quantization.m_Offset = MakeOptional<int32_t>(offset);
278}
279
280Optional<unsigned int> TensorInfo::GetQuantizationDim() const
281{
282 return m_Quantization.m_QuantizationDim;
283}
284
285void TensorInfo::SetQuantizationDim(const Optional<unsigned int>& quantizationDim)
286{
287 m_Quantization.m_QuantizationDim = quantizationDim;
288}
289
290bool TensorInfo::IsQuantized() const
291{
Derek Lambertid466a542020-01-22 15:37:29 +0000292 return IsQuantizedType(m_DataType);
Aron Virginas-Tarc0a87c12019-10-29 17:58:36 +0000293}
294
// ---
// --- BaseTensor
// ---
298
299template<typename MemoryType>
300BaseTensor<MemoryType>::BaseTensor()
301 : m_MemoryArea(nullptr)
302{
303}
304
305template<typename MemoryType>
306BaseTensor<MemoryType>::BaseTensor(const TensorInfo& info, MemoryType memoryArea)
307 : m_MemoryArea(memoryArea)
308 , m_Info(info)
309{
310}
311
312template<typename MemoryType>
313BaseTensor<MemoryType>::BaseTensor(const BaseTensor<MemoryType>& other)
314 : m_MemoryArea(other.m_MemoryArea)
315 , m_Info(other.GetInfo())
316{
317}
318
319template<typename MemoryType>
320BaseTensor<MemoryType>& BaseTensor<MemoryType>::operator =(const BaseTensor<MemoryType>& other)
321{
322 m_Info = other.m_Info;
323 m_MemoryArea = other.m_MemoryArea;
324 return *this;
325}
326
telsoa01c577f2c2018-08-31 09:22:23 +0100327// Explicit instantiations.
telsoa014fcda012018-03-09 14:13:49 +0000328template class BaseTensor<const void*>;
329template class BaseTensor<void*>;
330
331} // namespace armnn