blob: de27630e89f104f3a8758b829e216605ba3aba87 [file] [log] [blame]
Aron Virginas-Tarf03fcf02019-07-09 17:44:24 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "OutputShapeUtils.hpp"
7
8#include <algorithm>
9
10namespace armnn_driver
11{
12
13using namespace armnn;
14
15TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorShape& alphaShape)
16{
17 // NOTE: The inferred PReLU output size will be the maximum size along each dimension
18 // of input and alpha, starting with the trailing dimensions, and working its way forward.
19 //
20 // Example: inputShape={4, 1, 2}, alphaShape={5, 4, 3, 1} => outputShape={5, 4, 3, 2}
21
22 const unsigned int numInputDims = inputShape.GetNumDimensions();
23 const unsigned int numAlphaDims = alphaShape.GetNumDimensions();
24
25 const unsigned int maxNumDims = std::max(numInputDims, numAlphaDims);
26
27 TensorShape outputShape = TensorShape(maxNumDims);
28 for (unsigned int reverseIdx = 1u; reverseIdx <= maxNumDims; ++reverseIdx)
29 {
30 const int inputIdx = numInputDims - reverseIdx;
31 const int alphaIdx = numAlphaDims - reverseIdx;
32
33 const unsigned int inputDimSize = inputIdx >= 0 ? inputShape[inputIdx] : 0u;
34 const unsigned int alphaDimSize = alphaIdx >= 0 ? alphaShape[alphaIdx] : 0u;
35
36 const unsigned int outputIdx = maxNumDims - reverseIdx;
37 outputShape[outputIdx] = std::max(inputDimSize, alphaDimSize);
38 }
39
40 return outputShape;
41}
42
43} // namespace armnn_driver