blob: 739038ac2931d622ca6824965c7bc747380045f6 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "OutputShapeUtils.hpp"
7
8#include <algorithm>
Sadik Armagan310d8ff2019-07-11 10:53:38 +01009#include <vector>
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010010
Sadik Armagan5e9521c2019-07-12 13:55:57 +010011namespace
12{
13
14using namespace armnn;
15
16TensorShape CalculateMaxShape(const TensorShape& inShape0, const TensorShape& inShape1)
17{
18 // NOTE: The inferred output size will be the maximum size along each dimension
19 // of inShape0 and inShape1, starting with the trailing dimensions, and working its way forward.
20 //
21 // Example: inShape0={4, 1, 2}, inShape1={5, 4, 3, 1} => outputShape={5, 4, 3, 2}
22
23 const unsigned int numInput0Dims = inShape0.GetNumDimensions();
24 const unsigned int numInput1Dims = inShape1.GetNumDimensions();
25
26 const unsigned int maxNumDims = std::max(numInput0Dims, numInput1Dims);
27
28 TensorShape outputShape = TensorShape(maxNumDims);
29 for (unsigned int reverseIdx = 1u; reverseIdx <= maxNumDims; ++reverseIdx)
30 {
31 const int input0Idx = numInput0Dims - reverseIdx;
32 const int input1Idx = numInput1Dims - reverseIdx;
33
34 const unsigned int input0DimSize = input0Idx >= 0 ? inShape0[input0Idx] : 0u;
35 const unsigned int input1DimSize = input1Idx >= 0 ? inShape1[input1Idx] : 0u;
36
37 const unsigned int outputIdx = maxNumDims - reverseIdx;
38 outputShape[outputIdx] = std::max(input0DimSize, input1DimSize);
39 }
40
41 return outputShape;
42}
43
44} // namespace annonymous
45
46
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010047namespace armnn_driver
48{
49
50using namespace armnn;
51
Aron Virginas-Tar366e0a62019-07-10 13:01:41 +010052bool IsDynamicOutput(const TensorInfo& outputInfo)
53{
54 return outputInfo.GetNumElements() == 0u;
55}
56
Sadik Armagan310d8ff2019-07-11 10:53:38 +010057TensorShape InferPadOutputShape(const TensorShape& inputShape,
58 const std::vector<std::pair<unsigned int, unsigned int>>& padList)
59{
60 const unsigned int numDims = inputShape.GetNumDimensions();
61
62 std::vector<unsigned int> outputDims;
63 TensorShape outputShape = TensorShape(numDims);
64 for (unsigned int dim = 0; dim < numDims; ++dim)
65 {
66 unsigned int dimSize = inputShape[dim];
67 const std::pair<unsigned int, unsigned int>& dimPadding = padList[dim];
68 dimSize += dimPadding.first;
69 dimSize += dimPadding.second;
70 outputShape[dim] = dimSize;
71 }
72 return outputShape;
73}
74
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010075TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorShape& alphaShape)
76{
Sadik Armagan5e9521c2019-07-12 13:55:57 +010077 return CalculateMaxShape(inputShape, alphaShape);
78}
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010079
Sadik Armagan5e9521c2019-07-12 13:55:57 +010080TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape)
81{
82 return CalculateMaxShape(input0Shape, input1Shape);
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010083}
84
85} // namespace armnn_driver