blob: ab68097247088ab6d747192746eace852f47f6d7 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
#include "armnn/Descriptors.hpp"
#include "armnn/Logging.hpp"

#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <algorithm>
#include <array>
#include <cstring>
#include <limits>
#include <vector>

#include <fmt/format.h>
17namespace armnn
18{
19
20PermutationVector::PermutationVector(const ValueType *dimMappings, const SizeType numDimMappings)
21{
22 // Validation
23
24 if (numDimMappings > MaxNumOfTensorDimensions)
25 {
Matthew Sloyanf290d882020-10-12 15:03:01 +010026 throw InvalidArgumentException(
27 fmt::format("The number of mappings ({0}) cannot be greater "
28 "than the maximum number of dimensions supported ({1})",
29 numDimMappings,
30 MaxNumOfTensorDimensions));
telsoa014fcda012018-03-09 14:13:49 +000031 }
32
33 if ((dimMappings == nullptr) && (numDimMappings != 0))
34 {
35 throw InvalidArgumentException("Dimension mappings must not be NULL if the number of mappings is positive");
36 }
37
38 for (SizeType i = 0; i < numDimMappings; ++i)
39 {
40 const ValueType dstIndex = dimMappings[i];
41 if (dstIndex >= numDimMappings)
42 {
Matthew Sloyanf290d882020-10-12 15:03:01 +010043 throw InvalidArgumentException(
44 fmt::format("Dimension mapping at index {0} is invalid: "
45 "{1} is outside of the valid range [0,{2}]",
46 i,
47 dstIndex,
48 (numDimMappings - 1)));
telsoa014fcda012018-03-09 14:13:49 +000049 }
50 }
51
52 // Validation: Detect duplicates
53 {
54 std::array<bool, MaxNumOfTensorDimensions> observedDims;
55 observedDims.fill(false);
56
57 for (SizeType i = 0; i < numDimMappings; ++i)
58 {
59 const ValueType dstIndex = dimMappings[i];
60 if (observedDims[dstIndex])
61 {
62 throw InvalidArgumentException("Invalid dimension mappings: Two or more source dimensions are mapped "
63 "to the same output dimension");
64 }
65 observedDims[dstIndex] = true;
66 }
67 }
68
69 // Initialize
70 for (SizeType i = 0; i < numDimMappings; ++i)
71 {
72 m_DimMappings[i] = dimMappings[i];
73 }
74 m_NumDimMappings = numDimMappings;
75}
76
77PermutationVector::PermutationVector(std::initializer_list<ValueType> dimMappings)
Matthew Sloyan0663d662020-09-14 11:47:26 +010078 : PermutationVector(dimMappings.begin(), armnn::numeric_cast<SizeType>(dimMappings.size()))
telsoa014fcda012018-03-09 14:13:49 +000079{
80}
81
82OriginsDescriptor::OriginsDescriptor()
Nikhil Raj8599a412018-11-19 14:51:07 +000083: m_ConcatAxis(1)
84, m_NumViews(0)
telsoa014fcda012018-03-09 14:13:49 +000085, m_NumDimensions(0)
86, m_ViewOrigins(nullptr)
87{}
88
89OriginsDescriptor::OriginsDescriptor(uint32_t numViews, uint32_t numDimensions /*= 4*/)
Nikhil Raj8599a412018-11-19 14:51:07 +000090: m_ConcatAxis(1)
91, m_NumViews(numViews)
telsoa014fcda012018-03-09 14:13:49 +000092, m_NumDimensions(numDimensions)
93, m_ViewOrigins(numViews && numDimensions > 0 ? new uint32_t *[numViews]() : nullptr)
94{
95 for (uint32_t i = 0; m_NumDimensions > 0 && i < m_NumViews; ++i)
96 {
97 m_ViewOrigins[i] = new uint32_t[m_NumDimensions]();
98 }
99}
100
101OriginsDescriptor::OriginsDescriptor(const OriginsDescriptor& other)
Nikhil Raj8599a412018-11-19 14:51:07 +0000102: m_ConcatAxis(other.m_ConcatAxis)
103, m_NumViews(other.m_NumViews)
telsoa014fcda012018-03-09 14:13:49 +0000104, m_NumDimensions(other.m_NumDimensions)
105, m_ViewOrigins(other.m_NumViews && other.m_NumDimensions > 0 ? new uint32_t *[other.m_NumViews]() : nullptr)
106{
107 for (uint32_t i = 0; m_NumDimensions > 0 && i < m_NumViews; ++i)
108 {
109 m_ViewOrigins[i] = new uint32_t[m_NumDimensions]();
110 memcpy(m_ViewOrigins[i], other.m_ViewOrigins[i], m_NumDimensions * sizeof(uint32_t));
111 }
112}
113
114OriginsDescriptor::OriginsDescriptor(OriginsDescriptor&& other)
115: OriginsDescriptor()
116{
117 swap(*this, other);
118}
119
120OriginsDescriptor::~OriginsDescriptor()
121{
122 for (uint32_t i = 0; m_NumDimensions > 0 && i < m_NumViews; ++i)
123 {
124 delete[] m_ViewOrigins[i];
125 }
126 delete[] m_ViewOrigins;
127}
128
129OriginsDescriptor& OriginsDescriptor::operator=(OriginsDescriptor rhs)
130{
131 swap(*this, rhs);
132 return *this;
133}
134
Aron Virginas-Tar6fe52472019-10-15 17:35:36 +0100135bool OriginsDescriptor::operator==(const OriginsDescriptor& rhs) const
136{
137 if (GetNumViews() != rhs.GetNumViews() ||
138 GetNumDimensions() != rhs.GetNumDimensions() ||
139 GetConcatAxis() != rhs.GetConcatAxis())
140 {
141 return false;
142 }
143
144 for (unsigned int i = 0u; i < GetNumViews(); ++i)
145 {
146 for (unsigned int j = 0u; j < GetNumDimensions(); ++j)
147 {
148 if (GetViewOrigin(i)[j] != rhs.GetViewOrigin(i)[j])
149 {
150 return false;
151 }
152 }
153 }
154
155 return true;
156}
157
Nikhil Raj8599a412018-11-19 14:51:07 +0000158void OriginsDescriptor::SetConcatAxis(unsigned int concatAxis)
159{
160 m_ConcatAxis = concatAxis;
161}
162unsigned int OriginsDescriptor::GetConcatAxis() const
163{
164 return m_ConcatAxis;
165}
166
telsoa014fcda012018-03-09 14:13:49 +0000167Status OriginsDescriptor::SetViewOriginCoord(uint32_t view, uint32_t coord, uint32_t value)
168{
169 if (view >= m_NumViews)
170 {
Derek Lamberti08446972019-11-26 16:38:31 +0000171 ARMNN_LOG(error) << "OriginsDescriptor::SetViewOriginCoord: view argument:" << view <<
telsoa014fcda012018-03-09 14:13:49 +0000172 " is out of range";
173 return Status::Failure;
174 }
175 if (coord >= m_NumDimensions)
176 {
Derek Lamberti08446972019-11-26 16:38:31 +0000177 ARMNN_LOG(error) << "OriginsDescriptor::SetViewOriginCoord: coord argument:" << coord <<
telsoa014fcda012018-03-09 14:13:49 +0000178 " is out of range";
179 return Status::Failure;
180 }
181
182 m_ViewOrigins[view][coord] = value;
183 return Status::Success;
184}
185
186
187uint32_t OriginsDescriptor::GetNumViews() const
188{
189 return m_NumViews;
190}
191
192uint32_t OriginsDescriptor::GetNumDimensions() const
193{
194 return m_NumDimensions;
195}
196
197const uint32_t* OriginsDescriptor::GetViewOrigin(uint32_t idx) const
198{
199 return m_ViewOrigins ? m_ViewOrigins[idx] : nullptr;
200}
201
202
telsoa01c577f2c2018-08-31 09:22:23 +0100203// Reorders the viewOrigins in accordance with the indices presented in newOrdering array.
telsoa014fcda012018-03-09 14:13:49 +0000204void OriginsDescriptor::ReorderOrigins(unsigned int* newOrdering, unsigned int numNewOrdering)
205{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100206 ARMNN_ASSERT_MSG(m_NumViews == numNewOrdering, "number of views must match number of "
telsoa014fcda012018-03-09 14:13:49 +0000207 "elements in the new ordering array");
208 std::vector<uint32_t*> viewOrigins(&m_ViewOrigins[0], &m_ViewOrigins[m_NumViews]);
209
210 for (unsigned int i = 0; i < numNewOrdering; ++i)
211 {
212 m_ViewOrigins[i] = viewOrigins[newOrdering[i]];
213 }
214}
215
216ViewsDescriptor::ViewsDescriptor()
217: m_Origins()
218, m_ViewSizes(nullptr)
219{}
220
221ViewsDescriptor::ViewsDescriptor(uint32_t numViews, uint32_t numDimensions /*= 4*/)
222 : m_Origins(numViews, numDimensions)
surmeh013537c2c2018-05-18 16:31:43 +0100223 , m_ViewSizes(numViews > 0 && numDimensions > 0 ?
224 new uint32_t *[numViews]() : nullptr)
telsoa014fcda012018-03-09 14:13:49 +0000225{
surmeh013537c2c2018-05-18 16:31:43 +0100226 if (m_ViewSizes)
telsoa014fcda012018-03-09 14:13:49 +0000227 {
surmeh013537c2c2018-05-18 16:31:43 +0100228 for (uint32_t i = 0; GetNumDimensions() > 0 && i < GetNumViews(); ++i)
229 {
230 m_ViewSizes[i] = new uint32_t[GetNumDimensions()]();
231 }
telsoa014fcda012018-03-09 14:13:49 +0000232 }
233}
234
235ViewsDescriptor::ViewsDescriptor(const ViewsDescriptor& other)
236 : m_Origins(other.m_Origins)
surmeh013537c2c2018-05-18 16:31:43 +0100237 , m_ViewSizes(other.GetNumViews() > 0 && other.GetNumDimensions() > 0 ?
238 new uint32_t *[other.GetNumViews()]() : nullptr)
telsoa014fcda012018-03-09 14:13:49 +0000239{
surmeh013537c2c2018-05-18 16:31:43 +0100240 if (m_ViewSizes)
telsoa014fcda012018-03-09 14:13:49 +0000241 {
surmeh013537c2c2018-05-18 16:31:43 +0100242 for (uint32_t i = 0; GetNumDimensions() > 0 && i < GetNumViews(); ++i)
243 {
244 m_ViewSizes[i] = new uint32_t[GetNumDimensions()]();
245 memcpy(m_ViewSizes[i], other.m_ViewSizes[i], GetNumDimensions() * sizeof(uint32_t));
246 }
telsoa014fcda012018-03-09 14:13:49 +0000247 }
248}
249
250ViewsDescriptor::ViewsDescriptor(ViewsDescriptor&& other)
251 : ViewsDescriptor()
252{
253 swap(*this, other);
254}
255
256ViewsDescriptor::~ViewsDescriptor()
257{
surmeh013537c2c2018-05-18 16:31:43 +0100258 if (m_ViewSizes)
telsoa014fcda012018-03-09 14:13:49 +0000259 {
surmeh013537c2c2018-05-18 16:31:43 +0100260 for (uint32_t i = 0; GetNumDimensions() > 0 && i < GetNumViews(); ++i)
261 {
262 delete[] m_ViewSizes[i];
263 }
264 delete[] m_ViewSizes;
telsoa014fcda012018-03-09 14:13:49 +0000265 }
telsoa014fcda012018-03-09 14:13:49 +0000266}
267
268ViewsDescriptor& ViewsDescriptor::operator=(ViewsDescriptor rhs)
269{
270 swap(*this, rhs);
271 return *this;
272}
273
Aron Virginas-Tar6fe52472019-10-15 17:35:36 +0100274bool ViewsDescriptor::operator==(const ViewsDescriptor& rhs) const
275{
276 if (GetNumViews() != rhs.GetNumViews() || GetNumDimensions() != rhs.GetNumDimensions())
277 {
278 return false;
279 }
280
281 for (unsigned int i = 0u; i < GetNumViews(); ++i)
282 {
283 for (unsigned int j = 0u; j < GetNumDimensions(); ++j)
284 {
285 if (GetViewOrigin(i)[j] != rhs.GetViewOrigin(i)[j] || GetViewSizes(i)[j] != rhs.GetViewSizes(i)[j])
286 {
287 return false;
288 }
289 }
290 }
291
292 return true;
293}
294
telsoa014fcda012018-03-09 14:13:49 +0000295uint32_t ViewsDescriptor::GetNumViews() const
296{
297 return m_Origins.GetNumViews();
298}
299
300uint32_t ViewsDescriptor::GetNumDimensions() const
301{
302 return m_Origins.GetNumDimensions();
303}
304
305const uint32_t* ViewsDescriptor::GetViewOrigin(uint32_t idx) const
306{
307 return m_Origins.GetViewOrigin(idx);
308}
309
310Status ViewsDescriptor::SetViewOriginCoord(uint32_t view, uint32_t coord, uint32_t value)
311{
312 return m_Origins.SetViewOriginCoord(view, coord, value);
313}
314
315Status ViewsDescriptor::SetViewSize(uint32_t view, uint32_t coord, uint32_t value)
316{
surmeh013537c2c2018-05-18 16:31:43 +0100317 if (!m_ViewSizes)
318 {
Derek Lamberti08446972019-11-26 16:38:31 +0000319 ARMNN_LOG(error) << "ViewsDescriptor::SetViewSize: invalid view sizes";
surmeh013537c2c2018-05-18 16:31:43 +0100320 return Status::Failure;
321 }
322
telsoa014fcda012018-03-09 14:13:49 +0000323 if (view >= GetNumViews())
324 {
Derek Lamberti08446972019-11-26 16:38:31 +0000325 ARMNN_LOG(error) << "ViewsDescriptor::SetViewSize: view argument:" << view <<
telsoa014fcda012018-03-09 14:13:49 +0000326 " is out of range";
327 return Status::Failure;
328 }
329 if (coord >= GetNumDimensions())
330 {
Derek Lamberti08446972019-11-26 16:38:31 +0000331 ARMNN_LOG(error) << "ViewsDescriptor::SetViewSize: coord argument:" << coord <<
telsoa014fcda012018-03-09 14:13:49 +0000332 " is out of range";
333 return Status::Failure;
334 }
335
336 m_ViewSizes[view][coord] = value;
337 return Status::Success;
338}
339
340const uint32_t* ViewsDescriptor::GetViewSizes(uint32_t idx) const
341{
342 return m_ViewSizes ? m_ViewSizes[idx] : nullptr;
343}
344
Jim Flynn18ce3382019-03-08 11:08:30 +0000345const OriginsDescriptor& ViewsDescriptor::GetOrigins() const
346{
347 return m_Origins;
348}
349
telsoa014fcda012018-03-09 14:13:49 +0000350void swap(OriginsDescriptor& first, OriginsDescriptor& second)
351{
352 using std::swap;
353 swap(first.m_NumViews, second.m_NumViews);
354 swap(first.m_NumDimensions, second.m_NumDimensions);
355 swap(first.m_ViewOrigins, second.m_ViewOrigins);
Nikhil Raj8599a412018-11-19 14:51:07 +0000356 swap(first.m_ConcatAxis, second.m_ConcatAxis);
telsoa014fcda012018-03-09 14:13:49 +0000357}
358
359void swap(ViewsDescriptor& first, ViewsDescriptor& second)
360{
361 using std::swap;
362 swap(first.m_Origins, second.m_Origins);
363 swap(first.m_ViewSizes, second.m_ViewSizes);
364}
365
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +0000366int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape,
367 unsigned int axis) const
368{
369 int start = m_Begin[axis];
370
371 if (m_BeginMask & (1 << axis))
372 {
373 if (m_Stride[axis] > 0)
374 {
375 start = std::numeric_limits<int>::min();
376 }
377 else
378 {
379 start = std::numeric_limits<int>::max();
380 }
381 }
382
Matthew Sloyan0663d662020-09-14 11:47:26 +0100383 const int axisSize = armnn::numeric_cast<int>(inputShape[axis]);
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +0000384 if (start < 0)
385 {
386 start += (axisSize);
387 }
388
389 return std::max(0, std::min(start, axisSize - 1));
390
391}
392
393int StridedSliceDescriptor::GetStopForAxis(const TensorShape& inputShape,
394 unsigned int axis,
395 int startForAxis) const
396{
397
398 if (m_ShrinkAxisMask & (1 << axis))
399 {
400 return startForAxis + 1;
401 }
402
403 int stop = m_End[axis];
404
405 if (m_EndMask & (1 << axis))
406 {
407 if (m_Stride[axis] > 0)
408 {
409 stop = std::numeric_limits<int>::max();
410 }
411 else
412 {
413 stop = std::numeric_limits<int>::min();
414 }
415 }
416
Matthew Sloyan0663d662020-09-14 11:47:26 +0100417 const int axisSize = armnn::numeric_cast<int>(inputShape[axis]);
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +0000418 if (stop < 0)
419 {
420 stop += axisSize;
421 }
422
423 return m_Stride[axis] > 0 ? std::max(0, std::min(stop, axisSize)) :
424 std::max(-1, std::min(stop, axisSize - 1));
425
426}
427
Francis Murtagh080d5b72021-08-17 15:38:24 +0100428uint32_t FullyConnectedDescriptor::GetNumViews() const
429{
430 return GetNumInputs();
431}
432
Matthew Sloyan81beae32021-07-13 19:46:11 +0100433uint32_t FullyConnectedDescriptor::GetNumInputs() const
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000434{
Matthew Sloyan81beae32021-07-13 19:46:11 +0100435 // Return 2 otherwise check if bias is enabled
436 unsigned int numInputs = 2;
437 if (m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000438 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100439 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000440 }
441 return numInputs;
442}
443
telsoa014fcda012018-03-09 14:13:49 +0000444}