blob: 6461b37f75710cd1ff114d971f8ecb2a1acce6d9 [file] [log] [blame]
Nina Drozdd41b2592018-11-19 13:03:36 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/TypesUtils.hpp>
9
10namespace armnnUtils
11{
12armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
13 unsigned int numberOfChannels,
14 unsigned int height,
15 unsigned int width,
16 const armnn::DataLayout dataLayout);
17
18template<typename T>
19armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
20 unsigned int numberOfChannels,
21 unsigned int height,
22 unsigned int width,
23 const armnn::DataLayout dataLayout)
24{
25 switch (dataLayout)
26 {
27 case armnn::DataLayout::NCHW:
28 return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, armnn::GetDataType<T>());
29 case armnn::DataLayout::NHWC:
30 return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, armnn::GetDataType<T>());
31 default:
32 throw armnn::InvalidArgumentException("Unknown data layout ["
33 + std::to_string(static_cast<int>(dataLayout)) +
34 "]", CHECK_LOCATION());
35 }
36}
37} // namespace armnnUtils