blob: ebe9d2cfd52f8b3b056f2761e656cfa832edffac [file] [log] [blame]
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001//
Teresa Charlinf77cab52023-06-01 16:15:13 +01002// Copyright © 2017-2020,2023 Arm Ltd and Contributors. All rights reserved.
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "BatchToSpaceNd.hpp"
7
Teresa Charlinf77cab52023-06-01 16:15:13 +01008#include <armnnUtils/DataLayoutIndexed.hpp>
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00009
Matteo Martincigh21350152018-11-28 16:22:22 +000010using namespace armnnUtils;
11
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000012namespace armnn
13{
14
Teresa Charlinf77cab52023-06-01 16:15:13 +010015unsigned int Offset(const TensorShape& shape,
16 unsigned int batch,
17 unsigned int height,
18 unsigned int width,
19 unsigned int channels,
20 const DataLayoutIndexed& dataLayout)
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000021{
Teresa Charlinf77cab52023-06-01 16:15:13 +010022 // 3D Tensors
23 unsigned int channelDimension3D = dataLayout.GetDataLayout() == DataLayout::NCHW ? 1 : 2;
24 if (shape.GetNumDimensions() == 3)
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000025 {
Teresa Charlinf77cab52023-06-01 16:15:13 +010026 return (batch * shape[dataLayout.GetHeightIndex()] + height) * shape[channelDimension3D] + channels;
27 }
28 // 4D Tensors
29 else if (shape.GetNumDimensions() == 4)
30 {
31 if (dataLayout.GetDataLayout() == DataLayout::NHWC)
32 {
33 return ((batch * shape[dataLayout.GetHeightIndex()] + height) *
34 shape[dataLayout.GetWidthIndex()] + width) *
35 shape[dataLayout.GetChannelsIndex()] + channels;
36 }
37 else
38 {
39 return ((batch * shape[dataLayout.GetChannelsIndex()] + channels) *
40 shape[dataLayout.GetHeightIndex()] + height) *
41 shape[dataLayout.GetWidthIndex()] + width;
42 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000043 }
44 else
45 {
Teresa Charlinf77cab52023-06-01 16:15:13 +010046 throw InvalidArgumentException("Tensor rank must be either 3 or 4", CHECK_LOCATION());
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000047 }
48}
49
Teresa Charlinf77cab52023-06-01 16:15:13 +010050void BatchToSpaceNd(const TensorInfo& inputInfo,
51 const TensorInfo& outputInfo,
52 const BatchToSpaceNdDescriptor& params,
53 Decoder<float>& inputData,
54 Encoder<float>& outputData)
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000055{
Teresa Charlinf77cab52023-06-01 16:15:13 +010056 unsigned int rank = inputInfo.GetNumDimensions();
57 if (rank != 3 && rank != 4 )
58 {
59 throw InvalidArgumentException("Tensor rank must be either 3 or 4, but it is " + std::to_string(rank),
60 CHECK_LOCATION());
61 }
Éanna Ó Catháin95807ce2018-11-12 17:14:43 +000062
Teresa Charlinf77cab52023-06-01 16:15:13 +010063 DataLayoutIndexed dataLayout = params.m_DataLayout;
64 unsigned int channelDimension3D = params.m_DataLayout == DataLayout::NCHW ? 1 : 2;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000065
Teresa Charlinf77cab52023-06-01 16:15:13 +010066 TensorShape inputShape = inputInfo.GetShape();
67 TensorShape outputShape = outputInfo.GetShape();
Éanna Ó Catháin95807ce2018-11-12 17:14:43 +000068
Teresa Charlinf77cab52023-06-01 16:15:13 +010069 const unsigned int inputBatchSize = inputShape[0];
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000070 const unsigned int outputBatchSize = outputShape[0];
Teresa Charlinf77cab52023-06-01 16:15:13 +010071
72 const unsigned int channels = (rank == 3) ? inputShape[channelDimension3D]
73 : inputShape[dataLayout.GetChannelsIndex()];
74
75 const unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
76 const unsigned int inputWidth = (rank == 3) ? 1 : inputShape[dataLayout.GetWidthIndex()];
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000077 const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
Teresa Charlinf77cab52023-06-01 16:15:13 +010078 const unsigned int outputWidth = (rank == 3) ? 1 : outputShape[dataLayout.GetWidthIndex()];
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000079
Teresa Charlinf77cab52023-06-01 16:15:13 +010080 const unsigned int blockHeight = params.m_BlockShape[0];
81 const unsigned int blockWidth = (rank == 3) ? 1 : params.m_BlockShape[1];
Éanna Ó Catháin95807ce2018-11-12 17:14:43 +000082
Teresa Charlinf77cab52023-06-01 16:15:13 +010083 const unsigned int cropsTop = params.m_Crops[0].first;
84 const unsigned int cropsLeft = (rank == 3) ? 0 : params.m_Crops[1].first;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000085
86 for (unsigned int inBatch = 0; inBatch < inputBatchSize; ++inBatch)
87 {
88 const unsigned int outBatch = inBatch % outputBatchSize;
89 const unsigned int spatialOffset = inBatch / outputBatchSize;
90
Teresa Charlinf77cab52023-06-01 16:15:13 +010091 for (unsigned int inH = 0; inH < inputHeight; ++inH)
92 {
93 const unsigned int outH = inH * blockHeight + spatialOffset / blockWidth - cropsTop;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000094
95 if (outH >= outputHeight)
96 {
97 continue;
98 }
99
Teresa Charlinf77cab52023-06-01 16:15:13 +0100100 for (unsigned int inW = 0; inW < inputWidth; ++inW)
101 {
102 const unsigned int outW = inW * blockWidth + spatialOffset % blockWidth - cropsLeft;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000103
104 if (outW >= outputWidth)
105 {
106 continue;
107 }
108
109 for (unsigned int c = 0; c < channels; c++)
110 {
111 unsigned int outOffset = Offset(outputShape, outBatch, outH, outW, c, dataLayout);
112 unsigned int inOffset = Offset(inputShape, inBatch, inH, inW, c, dataLayout);
Francis Murtagh47ea3c02019-06-20 12:07:19 +0100113
Teresa Charlinf77cab52023-06-01 16:15:13 +0100114 outputData[outOffset];
115 inputData[inOffset];
116 outputData.Set(inputData.Get());
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000117 }
118 }
119 }
120 }
121}
122
123} //namespace armnn