blob: 2286f8b6ed71c2a8570aa07dfc87562304efc116 [file] [log] [blame]
Sadik Armagan479045b2018-10-01 11:51:37 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserHelper.hpp"
7
8// armnnUtils
9#include "Permute.hpp"
10
11#include <boost/format.hpp>
12
13namespace armnnUtils
14{
15
16const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000019void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20 armnn::OriginsDescriptor& concatDescriptor,
21 const unsigned int& concatAxis,
22 unsigned int inputIndex,
23 unsigned int& mergeDimOrigin)
Sadik Armagan479045b2018-10-01 11:51:37 +010024{
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000025 const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26
Sadik Armagan479045b2018-10-01 11:51:37 +010027 // double check dimensions of the tensors
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000028 if (inputTensorInfo.GetNumDimensions() != inputRank)
Sadik Armagan479045b2018-10-01 11:51:37 +010029 {
30 throw armnn::ParseException(
31 boost::str(
32 boost::format(
33 "The number of dimensions: %1% for input tensors of the "
34 "concatenation op should be %2% %3%")
35 % inputTensorInfo.GetNumDimensions()
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000036 % inputRank
Sadik Armagan479045b2018-10-01 11:51:37 +010037 % CHECK_LOCATION().AsString()));
38 }
39
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000040 for (unsigned int j = 0; j < concatAxis; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010041 {
42 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
43 }
44
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000045 concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
46 mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
Sadik Armagan479045b2018-10-01 11:51:37 +010047
Nattapat Chaimanowong5e9d2982019-01-25 13:20:39 +000048 for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
Sadik Armagan479045b2018-10-01 11:51:37 +010049 {
50 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
51 }
52}
53
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000054void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo,
55 const std::set<unsigned int>& axisSet, bool keepDims,
56 armnn::TensorInfo& outputTensorInfo)
57{
58 std::vector<unsigned int> outputShapeVector;
59 bool dimensionFound = false;
60 unsigned int size = 1;
61
62 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
63 {
64 dimensionFound = false;
65 for (unsigned int axis: axisSet)
66 {
67 if (axis == i)
68 {
69 dimensionFound = true;
70 break;
71 }
72 }
73
74 if (!dimensionFound)
75 {
76 size *= inputTensorInfo.GetShape()[i];
77
78 if (keepDims)
79 {
80 outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
81 }
82 }
83 else
84 {
85 if (keepDims)
86 {
87 outputShapeVector.push_back(1);
88 }
89 }
90 }
91
92 if (keepDims)
93 {
94 armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
95 outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
96 }
97 else
98 {
99 outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
100 }
101}
102
Sadik Armagan479045b2018-10-01 11:51:37 +0100103} // namespace armnnUtils