blob: bedf8418ef983c493d8adb58256e5b298ad87cf1 [file] [log] [blame]
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BatchToSpaceNd.hpp"
7
8#include "RefWorkloadUtils.hpp"
9
10#include <armnn/Types.hpp>
11
12#include <boost/assert.hpp>
13
14namespace armnn
15{
16
17inline unsigned int Offset(const TensorShape& shape, unsigned int batch, unsigned int height, unsigned int width,
18 unsigned int channels, const DataLayoutIndexed& dataLayout)
19{
20 if (dataLayout.GetDataLayout() == DataLayout::NHWC)
21 {
22 return ((batch * shape[dataLayout.GetHeightIndex()] + height) * shape[dataLayout.GetWidthIndex()] + width) *
23 shape[dataLayout.GetChannelsIndex()] + channels;
24 }
25 else
26 {
27 return ((batch * shape[dataLayout.GetChannelsIndex()] + channels) *
28 shape[dataLayout.GetHeightIndex()] + height) *
29 shape[dataLayout.GetWidthIndex()] + width;
30 }
31}
32
33void BatchToSpaceNd(const DataLayoutIndexed& dataLayout,
34 const TensorInfo& inputTensorInfo,
35 const TensorInfo& outputTensorInfo,
36 const std::vector<unsigned int>& blockShape,
37 const std::vector<std::vector<unsigned int>>& cropsData,
38 const float* inputData,
39 float* outputData)
40{
41 TensorShape inputShape = inputTensorInfo.GetShape();
42 unsigned int inputNumDims = inputShape.GetNumDimensions();
43 if (inputNumDims != 4)
44 {
45 throw armnn::InvalidArgumentException("Expected Input with 4 Dimensions");
46 }
47
48 TensorShape outputShape = outputTensorInfo.GetShape();
49 unsigned int outputNumDims = outputShape.GetNumDimensions();
50 if (outputNumDims != 4)
51 {
52 throw armnn::InvalidArgumentException("Expected Output with 4 Dimensions");
53 }
54
55 const unsigned int inputBatchSize = inputShape[0];
56 const unsigned int channels = inputShape[dataLayout.GetChannelsIndex()];
57
58 const unsigned int outputBatchSize = outputShape[0];
59 const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
60 const unsigned int outputWidth = outputShape[dataLayout.GetWidthIndex()];
61
62 const unsigned int blockShapeHeight = blockShape[0];
63 const unsigned int blockShapeWidth = blockShape[1];
64
65 const unsigned int cropsTop = cropsData[0][0];
66 const unsigned int cropsLeft = cropsData[1][0];
67
68 for (unsigned int inBatch = 0; inBatch < inputBatchSize; ++inBatch)
69 {
70 const unsigned int outBatch = inBatch % outputBatchSize;
71 const unsigned int spatialOffset = inBatch / outputBatchSize;
72
73 for (unsigned int inH = 0; inH < inputTensorInfo.GetShape()[dataLayout.GetHeightIndex()]; ++inH) {
74 const unsigned int outH = inH * blockShapeHeight + spatialOffset / blockShapeWidth - cropsTop;
75
76 if (outH >= outputHeight)
77 {
78 continue;
79 }
80
81 for (unsigned int inW = 0; inW < inputTensorInfo.GetShape()[dataLayout.GetWidthIndex()]; ++inW) {
82 const unsigned int outW = inW * blockShapeWidth + spatialOffset % blockShapeWidth - cropsLeft;
83
84 if (outW >= outputWidth)
85 {
86 continue;
87 }
88
89 for (unsigned int c = 0; c < channels; c++)
90 {
91 unsigned int outOffset = Offset(outputShape, outBatch, outH, outW, c, dataLayout);
92 unsigned int inOffset = Offset(inputShape, inBatch, inH, inW, c, dataLayout);
93 outputData[outOffset] = inputData[inOffset];
94 }
95 }
96 }
97 }
98}
99
100} //namespace armnn