blob: bf5ffdf0adb45460bf4379889d455699550c8cba [file] [log] [blame]
Sadik Armagan479045b2018-10-01 11:51:37 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserHelper.hpp"
7
8// armnnUtils
9#include "Permute.hpp"
10
11#include <boost/format.hpp>
12
13namespace armnnUtils
14{
15
16const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18
19void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
20 const unsigned int& concatAxis, unsigned int inputIndex,
21 std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
22{
23 // double check dimensions of the tensors
24 if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
25 {
26 throw armnn::ParseException(
27 boost::str(
28 boost::format(
29 "The number of dimensions: %1% for input tensors of the "
30 "concatenation op should be %2% %3%")
31 % inputTensorInfo.GetNumDimensions()
32 % armnn::MaxNumOfTensorDimensions
33 % CHECK_LOCATION().AsString()));
34 }
35
36 // if concatenation axis is 3 then need to be permuted
37 if (concatAxis == 3)
38 {
39 inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
40 }
41
42 for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
43 {
44 mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
45 }
46
47 // Concatenation dimension 1 is the only dimension supported in ArmNN
48 const unsigned int concatenationDim = 1;
49
50 for (unsigned int j = 0; j < concatenationDim; ++j)
51 {
52 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
53 }
54
55 concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
56 mergeDim += mergeDimSizes[concatenationDim];
57
58 for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
59 {
60 concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
61 }
62}
63
64} // namespace armnnUtils