blob: 8e72d4694c1a912a1e7979f8bed580992bd2ba3f [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "armnn/Tensor.hpp"
6#include "armnn/Utils.hpp"
7#include "armnn/Exceptions.hpp"
8#include "armnn/TypesUtils.hpp"
9
10#include <boost/assert.hpp>
11#include <boost/log/trivial.hpp>
12#include <boost/numeric/conversion/cast.hpp>
13
14namespace armnn
15{
16
17// ---
18// --- TensorShape
19// ---
20
21TensorShape::TensorShape()
22 : m_NumDimensions(0)
23{
24}
25
26TensorShape::TensorShape(const unsigned int numDimensions, const unsigned int* const dimensionSizes)
27 : m_NumDimensions(numDimensions)
28{
29 if (numDimensions < 1)
30 {
31 throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
32 }
33
34 if (numDimensions > MaxNumOfTensorDimensions)
35 {
36 throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
37 }
38
39 if (dimensionSizes == nullptr)
40 {
41 throw InvalidArgumentException("Tensor dimensionSizes must not be NULL");
42 }
43
44 std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
45}
46
47TensorShape::TensorShape(std::initializer_list<unsigned int> dimensionSizeList)
48 : TensorShape(boost::numeric_cast<unsigned int>(dimensionSizeList.size()), dimensionSizeList.begin())
49{
50}
51
52TensorShape::TensorShape(const TensorShape& other)
53 : m_NumDimensions(other.m_NumDimensions)
54{
55 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
56}
57
58TensorShape& TensorShape::operator =(const TensorShape& other)
59{
60 m_NumDimensions = other.m_NumDimensions;
61 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
62 return *this;
63}
64
65bool TensorShape::operator==(const TensorShape& other) const
66{
67 return ((m_NumDimensions == other.m_NumDimensions) &&
68 std::equal(m_Dimensions.cbegin(), m_Dimensions.cbegin() + m_NumDimensions, other.m_Dimensions.cbegin()));
69}
70
71bool TensorShape::operator!=(const TensorShape& other) const
72{
73 return !(*this == other);
74}
75
76unsigned int TensorShape::GetNumElements() const
77{
78 if (m_NumDimensions == 0)
79 {
80 return 0;
81 }
82
83 unsigned int count = 1;
84 for (unsigned int i = 0; i < m_NumDimensions; i++)
85 {
86 count *= m_Dimensions[i];
87 }
88
89 return count;
90}
91
92// ---
93// --- TensorInfo
94// ---
95
96TensorInfo::TensorInfo()
97: m_DataType(DataType::Float32)
98{
99}
100
101TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType,
102 float quantizationScale, int32_t quantizationOffset)
103 : m_Shape(shape)
104 , m_DataType(dataType)
105{
106 m_Quantization.m_Scale = quantizationScale;
107 m_Quantization.m_Offset = quantizationOffset;
108}
109
110TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType,
111 float quantizationScale, int32_t quantizationOffset)
112 : m_Shape(numDimensions, dimensionSizes)
113 , m_DataType(dataType)
114{
115 m_Quantization.m_Scale = quantizationScale;
116 m_Quantization.m_Offset = quantizationOffset;
117}
118
119TensorInfo::TensorInfo(const TensorInfo& other)
120: m_Shape(other.m_Shape)
121, m_DataType(other.m_DataType)
122, m_Quantization(other.m_Quantization)
123{
124}
125
126TensorInfo& TensorInfo::operator=(const TensorInfo& other)
127{
128 m_Shape = other.m_Shape;
129 m_DataType = other.m_DataType;
130 m_Quantization = other.m_Quantization;
131 return *this;
132}
133
134bool TensorInfo::operator==(const TensorInfo& other) const
135{
136 return ((m_Shape == other.m_Shape) &&
137 (m_DataType == other.m_DataType) &&
138 (m_Quantization == other.m_Quantization));
139}
140
141bool TensorInfo::operator!=(const TensorInfo& other) const
142{
143 return !(*this == other);
144}
145
146unsigned int TensorInfo::GetNumBytes() const
147{
148 return GetDataTypeSize(m_DataType) * GetNumElements();
149}
150
151// ---
152// --- BaseTensor
153// ---
154
155template<typename MemoryType>
156BaseTensor<MemoryType>::BaseTensor()
157 : m_MemoryArea(nullptr)
158{
159}
160
161template<typename MemoryType>
162BaseTensor<MemoryType>::BaseTensor(const TensorInfo& info, MemoryType memoryArea)
163 : m_MemoryArea(memoryArea)
164 , m_Info(info)
165{
166}
167
168template<typename MemoryType>
169BaseTensor<MemoryType>::BaseTensor(const BaseTensor<MemoryType>& other)
170 : m_MemoryArea(other.m_MemoryArea)
171 , m_Info(other.GetInfo())
172{
173}
174
175template<typename MemoryType>
176BaseTensor<MemoryType>& BaseTensor<MemoryType>::operator =(const BaseTensor<MemoryType>& other)
177{
178 m_Info = other.m_Info;
179 m_MemoryArea = other.m_MemoryArea;
180 return *this;
181}
182
// Explicit instantiations. The BaseTensor member templates are defined in this .cpp rather than
// in a header, so the two specialisations below (const void* and void* memory areas) are emitted here.
template class BaseTensor<const void*>;
template class BaseTensor<void*>;
186
187} // namespace armnn