blob: b6cdb3160da251017169b72736197a68ef44d6a7 [file] [log] [blame]
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "OutputShapeUtils.hpp"
7
8#include <algorithm>
Sadik Armagan310d8ff2019-07-11 10:53:38 +01009#include <vector>
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010010
Sadik Armagan5e9521c2019-07-12 13:55:57 +010011namespace
12{
13
14using namespace armnn;
15
16TensorShape CalculateMaxShape(const TensorShape& inShape0, const TensorShape& inShape1)
17{
18 // NOTE: The inferred output size will be the maximum size along each dimension
19 // of inShape0 and inShape1, starting with the trailing dimensions, and working its way forward.
20 //
21 // Example: inShape0={4, 1, 2}, inShape1={5, 4, 3, 1} => outputShape={5, 4, 3, 2}
22
23 const unsigned int numInput0Dims = inShape0.GetNumDimensions();
24 const unsigned int numInput1Dims = inShape1.GetNumDimensions();
25
26 const unsigned int maxNumDims = std::max(numInput0Dims, numInput1Dims);
27
28 TensorShape outputShape = TensorShape(maxNumDims);
29 for (unsigned int reverseIdx = 1u; reverseIdx <= maxNumDims; ++reverseIdx)
30 {
31 const int input0Idx = numInput0Dims - reverseIdx;
32 const int input1Idx = numInput1Dims - reverseIdx;
33
34 const unsigned int input0DimSize = input0Idx >= 0 ? inShape0[input0Idx] : 0u;
35 const unsigned int input1DimSize = input1Idx >= 0 ? inShape1[input1Idx] : 0u;
36
37 const unsigned int outputIdx = maxNumDims - reverseIdx;
38 outputShape[outputIdx] = std::max(input0DimSize, input1DimSize);
39 }
40
41 return outputShape;
42}
43
} // namespace anonymous
45
46
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010047namespace armnn_driver
48{
49
50using namespace armnn;
51
Aron Virginas-Tar366e0a62019-07-10 13:01:41 +010052bool IsDynamicOutput(const TensorInfo& outputInfo)
53{
54 return outputInfo.GetNumElements() == 0u;
55}
56
Narumol Prangnawarat95b1ef62019-07-15 12:02:20 +010057TensorShape InferMaximumOutputShape(const armnn::TensorShape& input0Shape,
58 const armnn::TensorShape& input1Shape)
59{
60 return CalculateMaxShape(input0Shape, input1Shape);
61}
62
Sadik Armagan310d8ff2019-07-11 10:53:38 +010063TensorShape InferPadOutputShape(const TensorShape& inputShape,
64 const std::vector<std::pair<unsigned int, unsigned int>>& padList)
65{
66 const unsigned int numDims = inputShape.GetNumDimensions();
67
68 std::vector<unsigned int> outputDims;
69 TensorShape outputShape = TensorShape(numDims);
70 for (unsigned int dim = 0; dim < numDims; ++dim)
71 {
72 unsigned int dimSize = inputShape[dim];
73 const std::pair<unsigned int, unsigned int>& dimPadding = padList[dim];
74 dimSize += dimPadding.first;
75 dimSize += dimPadding.second;
76 outputShape[dim] = dimSize;
77 }
78 return outputShape;
79}
80
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010081TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorShape& alphaShape)
82{
Sadik Armagan5e9521c2019-07-12 13:55:57 +010083 return CalculateMaxShape(inputShape, alphaShape);
84}
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010085
Sadik Armagan5e9521c2019-07-12 13:55:57 +010086TensorShape InferSubOutputShape(const TensorShape& input0Shape, const TensorShape& input1Shape)
87{
88 return CalculateMaxShape(input0Shape, input1Shape);
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +010089}
90
91} // namespace armnn_driver