blob: 9709773014b4718f59c452c5e084afc93d0707c7 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Sadik Armagan479045b2018-10-01 11:51:37 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserHelper.hpp"
7
Matthew Benthamff130e22020-01-17 11:47:42 +00008#include <armnn/Descriptors.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/Permute.hpp>
Sadik Armagan479045b2018-10-01 11:51:37 +010010
11#include <boost/format.hpp>
12
13namespace armnnUtils
14{
15
16const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000019void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20 armnn::OriginsDescriptor& concatDescriptor,
21 const unsigned int& concatAxis,
22 unsigned int inputIndex,
23 unsigned int& mergeDimOrigin)
Sadik Armagan479045b2018-10-01 11:51:37 +010024{
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000025 const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26
Sadik Armagan479045b2018-10-01 11:51:37 +010027 // double check dimensions of the tensors
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000028 if (inputTensorInfo.GetNumDimensions() != inputRank)
Sadik Armagan479045b2018-10-01 11:51:37 +010029 {
30 throw armnn::ParseException(
31 boost::str(
32 boost::format(
33 "The number of dimensions: %1% for input tensors of the "
34 "concatenation op should be %2% %3%")
35 % inputTensorInfo.GetNumDimensions()
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000036 % inputRank
Sadik Armagan479045b2018-10-01 11:51:37 +010037 % CHECK_LOCATION().AsString()));
38 }
39
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000040 for (unsigned int j = 0; j < concatAxis; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010041 {
42 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
43 }
44
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000045 concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
46 mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
Sadik Armagan479045b2018-10-01 11:51:37 +010047
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000048 for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010049 {
50 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
51 }
52}
53
Derek Lambertibaa177f2019-12-10 22:00:43 +000054void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
55 const std::set<unsigned int>& axisSet,
56 bool keepDims,
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000057 armnn::TensorInfo& outputTensorInfo)
58{
59 std::vector<unsigned int> outputShapeVector;
60 bool dimensionFound = false;
61 unsigned int size = 1;
62
63 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
64 {
65 dimensionFound = false;
66 for (unsigned int axis: axisSet)
67 {
68 if (axis == i)
69 {
70 dimensionFound = true;
71 break;
72 }
73 }
74
75 if (!dimensionFound)
76 {
77 size *= inputTensorInfo.GetShape()[i];
78
79 if (keepDims)
80 {
81 outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
82 }
83 }
84 else
85 {
86 if (keepDims)
87 {
88 outputShapeVector.push_back(1);
89 }
90 }
91 }
92
93 if (keepDims)
94 {
95 armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
96 outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
97 }
98 else
99 {
100 outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
101 }
102}
103
Georgios Pinitas5e90aab2020-02-14 14:46:51 +0000104
105void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
106 const armnn::StridedSliceDescriptor& desc,
107 armnn::TensorInfo& outputTensorInfo)
108{
109 const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
110
111 std::vector<unsigned int> outputShapeVector;
112 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
113 {
114 if (desc.m_ShrinkAxisMask & (1 << i))
115 {
116 continue;
117 }
118
119 int stride = desc.m_Stride[i];
120 int start = desc.GetStartForAxis(inputShape, i);
121 int stop = desc.GetStopForAxis(inputShape, i, start);
122
123 int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
124 ((start - stop) - stride - 1) / -stride;
125
126 newSize = std::max(0, newSize);
127
128 outputShapeVector.push_back(static_cast<unsigned int>(newSize));
129 }
130
131 armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
132 outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
133}
Sadik Armagan479045b2018-10-01 11:51:37 +0100134} // namespace armnnUtils