blob: 2c25eec163cf93975be28de711aee9ce483b1bfc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "TensorUtils.hpp"
7
8namespace armnnUtils
9{
10
11armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
12 unsigned int numberOfChannels,
13 unsigned int height,
14 unsigned int width,
15 const armnn::DataLayout dataLayout)
16{
17 switch (dataLayout)
18 {
19 case armnn::DataLayout::NCHW:
20 return armnn::TensorShape({numberOfBatches, numberOfChannels, height, width});
21 case armnn::DataLayout::NHWC:
22 return armnn::TensorShape({numberOfBatches, height, width, numberOfChannels});
23 default:
24 throw armnn::InvalidArgumentException("Unknown data layout ["
25 + std::to_string(static_cast<int>(dataLayout)) +
26 "]", CHECK_LOCATION());
27 }
28}
29
30}
31