blob: af8014d112fe475018c225a7df7f00f0b3a78038 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Sadik Armagan479045b2018-10-01 11:51:37 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserHelper.hpp"
7
Matthew Benthamff130e22020-01-17 11:47:42 +00008#include <armnn/Descriptors.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/Permute.hpp>
Sadik Armagan479045b2018-10-01 11:51:37 +010010
Colm Donelan5b5c2222020-09-09 12:48:16 +010011#include <fmt/format.h>
Sadik Armagan479045b2018-10-01 11:51:37 +010012
13namespace armnnUtils
14{
15
16const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000019void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20 armnn::OriginsDescriptor& concatDescriptor,
21 const unsigned int& concatAxis,
22 unsigned int inputIndex,
23 unsigned int& mergeDimOrigin)
Sadik Armagan479045b2018-10-01 11:51:37 +010024{
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000025 const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26
Sadik Armagan479045b2018-10-01 11:51:37 +010027 // double check dimensions of the tensors
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000028 if (inputTensorInfo.GetNumDimensions() != inputRank)
Sadik Armagan479045b2018-10-01 11:51:37 +010029 {
Colm Donelan5b5c2222020-09-09 12:48:16 +010030 throw armnn::ParseException(fmt::format(
31 "The number of dimensions: {0} for input tensors of the "
32 "concatenation op should be {1} {2}",
33 inputTensorInfo.GetNumDimensions(),
34 inputRank,
35 CHECK_LOCATION().AsString()));
Sadik Armagan479045b2018-10-01 11:51:37 +010036 }
37
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000038 for (unsigned int j = 0; j < concatAxis; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010039 {
40 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
41 }
42
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000043 concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
44 mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
Sadik Armagan479045b2018-10-01 11:51:37 +010045
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000046 for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010047 {
48 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
49 }
50}
51
Derek Lambertibaa177f2019-12-10 22:00:43 +000052void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
53 const std::set<unsigned int>& axisSet,
54 bool keepDims,
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000055 armnn::TensorInfo& outputTensorInfo)
56{
57 std::vector<unsigned int> outputShapeVector;
58 bool dimensionFound = false;
59 unsigned int size = 1;
60
61 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
62 {
63 dimensionFound = false;
64 for (unsigned int axis: axisSet)
65 {
66 if (axis == i)
67 {
68 dimensionFound = true;
69 break;
70 }
71 }
72
73 if (!dimensionFound)
74 {
75 size *= inputTensorInfo.GetShape()[i];
76
77 if (keepDims)
78 {
79 outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
80 }
81 }
82 else
83 {
84 if (keepDims)
85 {
86 outputShapeVector.push_back(1);
87 }
88 }
89 }
90
91 if (keepDims)
92 {
93 armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
94 outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
95 }
96 else
97 {
98 outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
99 }
100}
101
Georgios Pinitas5e90aab2020-02-14 14:46:51 +0000102
103void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
104 const armnn::StridedSliceDescriptor& desc,
105 armnn::TensorInfo& outputTensorInfo)
106{
107 const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
108
109 std::vector<unsigned int> outputShapeVector;
110 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
111 {
112 if (desc.m_ShrinkAxisMask & (1 << i))
113 {
114 continue;
115 }
116
117 int stride = desc.m_Stride[i];
118 int start = desc.GetStartForAxis(inputShape, i);
119 int stop = desc.GetStopForAxis(inputShape, i, start);
120
121 int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
122 ((start - stop) - stride - 1) / -stride;
123
124 newSize = std::max(0, newSize);
125
126 outputShapeVector.push_back(static_cast<unsigned int>(newSize));
127 }
128
129 armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
130 outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
131}
Sadik Armagan479045b2018-10-01 11:51:37 +0100132} // namespace armnnUtils