blob: b26f22043b6c14305cb758ec96cbdcd5efe0ecff [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Matteo Martincighe011d202019-11-28 11:35:47 +00002// Copyright © 2019 Arm Ltd. All rights reserved.
Matteo Martincigh21350152018-11-28 16:22:22 +00003// SPDX-License-Identifier: MIT
4//
Matteo Martincighee423ce2019-06-05 09:02:41 +01005
Matteo Martincigh21350152018-11-28 16:22:22 +00006#pragma once
Matteo Martincighee423ce2019-06-05 09:02:41 +01007
Matteo Martincigh21350152018-11-28 16:22:22 +00008#include <armnn/Types.hpp>
Matteo Martincighee423ce2019-06-05 09:02:41 +01009#include <armnn/Tensor.hpp>
Matteo Martincigh21350152018-11-28 16:22:22 +000010
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010011#include <armnn/utility/Assert.hpp>
Matteo Martincighf2aaab32019-06-06 15:46:22 +010012
Matteo Martincigh21350152018-11-28 16:22:22 +000013namespace armnnUtils
14{
15
Ryan OShea2bbfaa72020-02-12 16:15:27 +000016/// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
Matteo Martincigh21350152018-11-28 16:22:22 +000017class DataLayoutIndexed
18{
19public:
20 DataLayoutIndexed(armnn::DataLayout dataLayout);
21
22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24 unsigned int GetHeightIndex() const { return m_HeightIndex; }
25 unsigned int GetWidthIndex() const { return m_WidthIndex; }
Matteo Martincighf2aaab32019-06-06 15:46:22 +010026
27 inline unsigned int GetIndex(const armnn::TensorShape& shape,
28 unsigned int batchIndex, unsigned int channelIndex,
29 unsigned int heightIndex, unsigned int widthIndex) const
30 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010031 ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
32 ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
Matteo Martincighf2aaab32019-06-06 15:46:22 +010033 ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010034 ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] ||
Matteo Martincighf2aaab32019-06-06 15:46:22 +010035 ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010036 ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] ||
Matteo Martincighf2aaab32019-06-06 15:46:22 +010037 ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
38
Ryan OShea2bbfaa72020-02-12 16:15:27 +000039 /// Offset the given indices appropriately depending on the data layout
Matteo Martincighf2aaab32019-06-06 15:46:22 +010040 switch (m_DataLayout)
41 {
42 case armnn::DataLayout::NHWC:
43 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
44 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
45 widthIndex *= shape[m_ChannelsIndex];
Ryan OShea2bbfaa72020-02-12 16:15:27 +000046 /// channelIndex stays unchanged
Matteo Martincighf2aaab32019-06-06 15:46:22 +010047 break;
48 case armnn::DataLayout::NCHW:
49 default:
50 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
51 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
52 heightIndex *= shape[m_WidthIndex];
Ryan OShea2bbfaa72020-02-12 16:15:27 +000053 /// widthIndex stays unchanged
Matteo Martincighf2aaab32019-06-06 15:46:22 +010054 break;
55 }
56
Ryan OShea2bbfaa72020-02-12 16:15:27 +000057 /// Get the value using the correct offset
Matteo Martincighf2aaab32019-06-06 15:46:22 +010058 return batchIndex + channelIndex + heightIndex + widthIndex;
59 }
Matteo Martincigh21350152018-11-28 16:22:22 +000060
61private:
62 armnn::DataLayout m_DataLayout;
63 unsigned int m_ChannelsIndex;
64 unsigned int m_HeightIndex;
65 unsigned int m_WidthIndex;
66};
67
/// Equality operators between an armnn::DataLayout and a DataLayoutIndexed, declared for both argument orders (defined out of line).
Matteo Martincigh21350152018-11-28 16:22:22 +000069bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
70bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
71
72} // namespace armnnUtils