blob: 03404bda5de53395799626f2d31b37daf1c32a7d [file] [log] [blame]
Matteo Martincigh21350152018-11-28 16:22:22 +00001//
Matteo Martincighe011d202019-11-28 11:35:47 +00002// Copyright © 2019 Arm Ltd. All rights reserved.
Matteo Martincigh21350152018-11-28 16:22:22 +00003// SPDX-License-Identifier: MIT
4//
Matteo Martincighee423ce2019-06-05 09:02:41 +01005
Matteo Martincigh21350152018-11-28 16:22:22 +00006#pragma once
Matteo Martincighee423ce2019-06-05 09:02:41 +01007
Matteo Martincigh21350152018-11-28 16:22:22 +00008#include <armnn/Types.hpp>
Matteo Martincighee423ce2019-06-05 09:02:41 +01009#include <armnn/Tensor.hpp>
Matteo Martincigh21350152018-11-28 16:22:22 +000010
Matteo Martincighf2aaab32019-06-06 15:46:22 +010011#include <boost/assert.hpp>
12
Matteo Martincigh21350152018-11-28 16:22:22 +000013namespace armnnUtils
14{
15
16// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
17class DataLayoutIndexed
18{
19public:
20 DataLayoutIndexed(armnn::DataLayout dataLayout);
21
22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24 unsigned int GetHeightIndex() const { return m_HeightIndex; }
25 unsigned int GetWidthIndex() const { return m_WidthIndex; }
Matteo Martincighf2aaab32019-06-06 15:46:22 +010026
27 inline unsigned int GetIndex(const armnn::TensorShape& shape,
28 unsigned int batchIndex, unsigned int channelIndex,
29 unsigned int heightIndex, unsigned int widthIndex) const
30 {
31 BOOST_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
32 BOOST_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
33 ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
34 BOOST_ASSERT( heightIndex < shape[m_HeightIndex] ||
35 ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
36 BOOST_ASSERT( widthIndex < shape[m_WidthIndex] ||
37 ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
38
39 // Offset the given indices appropriately depending on the data layout
40 switch (m_DataLayout)
41 {
42 case armnn::DataLayout::NHWC:
43 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
44 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
45 widthIndex *= shape[m_ChannelsIndex];
46 // channelIndex stays unchanged
47 break;
48 case armnn::DataLayout::NCHW:
49 default:
50 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
51 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
52 heightIndex *= shape[m_WidthIndex];
53 // widthIndex stays unchanged
54 break;
55 }
56
57 // Get the value using the correct offset
58 return batchIndex + channelIndex + heightIndex + widthIndex;
59 }
Matteo Martincigh21350152018-11-28 16:22:22 +000060
61private:
62 armnn::DataLayout m_DataLayout;
63 unsigned int m_ChannelsIndex;
64 unsigned int m_HeightIndex;
65 unsigned int m_WidthIndex;
66};
67
68// Equality methods
69bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
70bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
71
72} // namespace armnnUtils