blob: e57cec531f63819a06b26455a023844e2e3de235 [file] [log] [blame]
//
// Copyright © 2018-2021,2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Matteo Martincighee423ce2019-06-05 09:02:41 +01005
Matteo Martincigh21350152018-11-28 16:22:22 +00006#pragma once
Matteo Martincighee423ce2019-06-05 09:02:41 +01007
Matteo Martincigh21350152018-11-28 16:22:22 +00008#include <armnn/Types.hpp>
Matteo Martincighee423ce2019-06-05 09:02:41 +01009#include <armnn/Tensor.hpp>
Matteo Martincigh21350152018-11-28 16:22:22 +000010
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010011#include <armnn/utility/Assert.hpp>
Matteo Martincighf2aaab32019-06-06 15:46:22 +010012
Matteo Martincigh21350152018-11-28 16:22:22 +000013namespace armnnUtils
14{
15
Ryan OShea2bbfaa72020-02-12 16:15:27 +000016/// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
Matteo Martincigh21350152018-11-28 16:22:22 +000017class DataLayoutIndexed
18{
19public:
20 DataLayoutIndexed(armnn::DataLayout dataLayout);
21
22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24 unsigned int GetHeightIndex() const { return m_HeightIndex; }
25 unsigned int GetWidthIndex() const { return m_WidthIndex; }
Matthew Sloyanb63a3112021-09-08 13:05:51 +010026 unsigned int GetDepthIndex() const { return m_DepthIndex; }
Matteo Martincighf2aaab32019-06-06 15:46:22 +010027
28 inline unsigned int GetIndex(const armnn::TensorShape& shape,
29 unsigned int batchIndex, unsigned int channelIndex,
30 unsigned int heightIndex, unsigned int widthIndex) const
31 {
Ryan OSheac229b3f2023-06-27 22:34:54 +010032 if (batchIndex >= shape[0] && !( shape[0] == 0 && batchIndex == 0))
33 {
34 throw armnn::Exception("Unable to get batch index", CHECK_LOCATION());
35 }
36 if (channelIndex >= shape[m_ChannelsIndex] &&
37 !(shape[m_ChannelsIndex] == 0 && channelIndex == 0))
38 {
39 throw armnn::Exception("Unable to get channel index", CHECK_LOCATION());
40
41 }
42 if (heightIndex >= shape[m_HeightIndex] &&
43 !( shape[m_HeightIndex] == 0 && heightIndex == 0))
44 {
45 throw armnn::Exception("Unable to get height index", CHECK_LOCATION());
46 }
47 if (widthIndex >= shape[m_WidthIndex] &&
48 ( shape[m_WidthIndex] == 0 && widthIndex == 0))
49 {
50 throw armnn::Exception("Unable to get width index", CHECK_LOCATION());
51 }
Matteo Martincighf2aaab32019-06-06 15:46:22 +010052
Ryan OShea2bbfaa72020-02-12 16:15:27 +000053 /// Offset the given indices appropriately depending on the data layout
Matteo Martincighf2aaab32019-06-06 15:46:22 +010054 switch (m_DataLayout)
55 {
56 case armnn::DataLayout::NHWC:
57 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
58 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
59 widthIndex *= shape[m_ChannelsIndex];
Ryan OShea2bbfaa72020-02-12 16:15:27 +000060 /// channelIndex stays unchanged
Matteo Martincighf2aaab32019-06-06 15:46:22 +010061 break;
62 case armnn::DataLayout::NCHW:
63 default:
64 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
65 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
66 heightIndex *= shape[m_WidthIndex];
Ryan OShea2bbfaa72020-02-12 16:15:27 +000067 /// widthIndex stays unchanged
Matteo Martincighf2aaab32019-06-06 15:46:22 +010068 break;
69 }
70
Ryan OShea2bbfaa72020-02-12 16:15:27 +000071 /// Get the value using the correct offset
Matteo Martincighf2aaab32019-06-06 15:46:22 +010072 return batchIndex + channelIndex + heightIndex + widthIndex;
73 }
Matteo Martincigh21350152018-11-28 16:22:22 +000074
75private:
76 armnn::DataLayout m_DataLayout;
77 unsigned int m_ChannelsIndex;
78 unsigned int m_HeightIndex;
79 unsigned int m_WidthIndex;
Matthew Sloyanb63a3112021-09-08 13:05:51 +010080 unsigned int m_DepthIndex;
Matteo Martincigh21350152018-11-28 16:22:22 +000081};
82
Ryan OShea2bbfaa72020-02-12 16:15:27 +000083/// Equality methods
Matteo Martincigh21350152018-11-28 16:22:22 +000084bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
85bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
86
87} // namespace armnnUtils