blob: 57f823fe13afd6d959cba15f615e47f961c264bd [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "TensorUtils.hpp"
7
8namespace armnnUtils
9{
10
11armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
12 unsigned int numberOfChannels,
13 unsigned int height,
14 unsigned int width,
15 const armnn::DataLayout dataLayout)
16{
17 switch (dataLayout)
18 {
19 case armnn::DataLayout::NCHW:
20 return armnn::TensorShape({numberOfBatches, numberOfChannels, height, width});
21 case armnn::DataLayout::NHWC:
22 return armnn::TensorShape({numberOfBatches, height, width, numberOfChannels});
23 default:
24 throw armnn::InvalidArgumentException("Unknown data layout ["
25 + std::to_string(static_cast<int>(dataLayout)) +
26 "]", CHECK_LOCATION());
27 }
28}
29
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000030armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
31 unsigned int numberOfChannels,
32 unsigned int height,
33 unsigned int width,
34 const armnn::DataLayout dataLayout,
35 const armnn::DataType dataType)
36{
37 switch (dataLayout)
38 {
39 case armnn::DataLayout::NCHW:
40 return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
41 case armnn::DataLayout::NHWC:
42 return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
43 default:
44 throw armnn::InvalidArgumentException("Unknown data layout ["
45 + std::to_string(static_cast<int>(dataLayout)) +
46 "]", CHECK_LOCATION());
47 }
Nina Drozdd41b2592018-11-19 13:03:36 +000048}
49
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000050}