blob: 163d34b15934952e2781f7c3d4a21c0b8f9ec999 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Matteo Martincighe011d202019-11-28 11:35:47 +00002// Copyright © 2019 Arm Ltd. All rights reserved.
Matteo Martincigh21350152018-11-28 16:22:22 +00003// SPDX-License-Identifier: MIT
4//
Matteo Martincighee423ce2019-06-05 09:02:41 +01005
Matteo Martincigh21350152018-11-28 16:22:22 +00006#pragma once
Matteo Martincighee423ce2019-06-05 09:02:41 +01007
Matteo Martincigh21350152018-11-28 16:22:22 +00008#include <armnn/Types.hpp>
Matteo Martincighee423ce2019-06-05 09:02:41 +01009#include <armnn/Tensor.hpp>
Matteo Martincigh21350152018-11-28 16:22:22 +000010
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010011#include <armnn/utility/Assert.hpp>
Matteo Martincighf2aaab32019-06-06 15:46:22 +010012
Matteo Martincigh21350152018-11-28 16:22:22 +000013namespace armnnUtils
14{
15
Ryan OShea2bbfaa72020-02-12 16:15:27 +000016/// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
Matteo Martincigh21350152018-11-28 16:22:22 +000017class DataLayoutIndexed
18{
19public:
20 DataLayoutIndexed(armnn::DataLayout dataLayout);
21
22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24 unsigned int GetHeightIndex() const { return m_HeightIndex; }
25 unsigned int GetWidthIndex() const { return m_WidthIndex; }
Matthew Sloyanb63a3112021-09-08 13:05:51 +010026 unsigned int GetDepthIndex() const { return m_DepthIndex; }
Matteo Martincighf2aaab32019-06-06 15:46:22 +010027
28 inline unsigned int GetIndex(const armnn::TensorShape& shape,
29 unsigned int batchIndex, unsigned int channelIndex,
30 unsigned int heightIndex, unsigned int widthIndex) const
31 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010032 ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
33 ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
Matteo Martincighf2aaab32019-06-06 15:46:22 +010034 ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010035 ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] ||
Matteo Martincighf2aaab32019-06-06 15:46:22 +010036 ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010037 ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] ||
Matteo Martincighf2aaab32019-06-06 15:46:22 +010038 ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
39
Ryan OShea2bbfaa72020-02-12 16:15:27 +000040 /// Offset the given indices appropriately depending on the data layout
Matteo Martincighf2aaab32019-06-06 15:46:22 +010041 switch (m_DataLayout)
42 {
43 case armnn::DataLayout::NHWC:
44 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
45 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
46 widthIndex *= shape[m_ChannelsIndex];
Ryan OShea2bbfaa72020-02-12 16:15:27 +000047 /// channelIndex stays unchanged
Matteo Martincighf2aaab32019-06-06 15:46:22 +010048 break;
49 case armnn::DataLayout::NCHW:
50 default:
51 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
52 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
53 heightIndex *= shape[m_WidthIndex];
Ryan OShea2bbfaa72020-02-12 16:15:27 +000054 /// widthIndex stays unchanged
Matteo Martincighf2aaab32019-06-06 15:46:22 +010055 break;
56 }
57
Ryan OShea2bbfaa72020-02-12 16:15:27 +000058 /// Get the value using the correct offset
Matteo Martincighf2aaab32019-06-06 15:46:22 +010059 return batchIndex + channelIndex + heightIndex + widthIndex;
60 }
Matteo Martincigh21350152018-11-28 16:22:22 +000061
62private:
63 armnn::DataLayout m_DataLayout;
64 unsigned int m_ChannelsIndex;
65 unsigned int m_HeightIndex;
66 unsigned int m_WidthIndex;
Matthew Sloyanb63a3112021-09-08 13:05:51 +010067 unsigned int m_DepthIndex;
Matteo Martincigh21350152018-11-28 16:22:22 +000068};
69
Ryan OShea2bbfaa72020-02-12 16:15:27 +000070/// Equality methods
Matteo Martincigh21350152018-11-28 16:22:22 +000071bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
72bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
73
74} // namespace armnnUtils