// Source blob: ca6e42696efa59635354bc8d51fb1e944d90d177
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ParserHelper.hpp"
7
Matthew Benthamff130e22020-01-17 11:47:42 +00008#include <armnn/Descriptors.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/Permute.hpp>
Sadik Armagan479045b2018-10-01 11:51:37 +010010
11#include <boost/format.hpp>
12
13namespace armnnUtils
14{
15
16const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000019void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20 armnn::OriginsDescriptor& concatDescriptor,
21 const unsigned int& concatAxis,
22 unsigned int inputIndex,
23 unsigned int& mergeDimOrigin)
Sadik Armagan479045b2018-10-01 11:51:37 +010024{
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000025 const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26
Sadik Armagan479045b2018-10-01 11:51:37 +010027 // double check dimensions of the tensors
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000028 if (inputTensorInfo.GetNumDimensions() != inputRank)
Sadik Armagan479045b2018-10-01 11:51:37 +010029 {
30 throw armnn::ParseException(
31 boost::str(
32 boost::format(
33 "The number of dimensions: %1% for input tensors of the "
34 "concatenation op should be %2% %3%")
35 % inputTensorInfo.GetNumDimensions()
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000036 % inputRank
Sadik Armagan479045b2018-10-01 11:51:37 +010037 % CHECK_LOCATION().AsString()));
38 }
39
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000040 for (unsigned int j = 0; j < concatAxis; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010041 {
42 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
43 }
44
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000045 concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
46 mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
Sadik Armagan479045b2018-10-01 11:51:37 +010047
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000048 for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010049 {
50 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
51 }
52}
53
Derek Lambertibaa177f2019-12-10 22:00:43 +000054void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
55 const std::set<unsigned int>& axisSet,
56 bool keepDims,
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000057 armnn::TensorInfo& outputTensorInfo)
58{
59 std::vector<unsigned int> outputShapeVector;
60 bool dimensionFound = false;
61 unsigned int size = 1;
62
63 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
64 {
65 dimensionFound = false;
66 for (unsigned int axis: axisSet)
67 {
68 if (axis == i)
69 {
70 dimensionFound = true;
71 break;
72 }
73 }
74
75 if (!dimensionFound)
76 {
77 size *= inputTensorInfo.GetShape()[i];
78
79 if (keepDims)
80 {
81 outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
82 }
83 }
84 else
85 {
86 if (keepDims)
87 {
88 outputShapeVector.push_back(1);
89 }
90 }
91 }
92
93 if (keepDims)
94 {
95 armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
96 outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
97 }
98 else
99 {
100 outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
101 }
102}
103
Sadik Armagan479045b2018-10-01 11:51:37 +0100104} // namespace armnnUtils