blob: 3bddfb0cab57a592a845443180ca913a7bf67153 [file] [log] [blame]
Ruomei Yan25339c32019-05-28 16:48:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "RefWorkloadUtils.hpp"
#include <backendsCommon/WorkloadData.hpp>
#include <armnn/Tensor.hpp>

#include <boost/assert.hpp>
#include "Splitter.hpp"

#include <cmath>
#include <limits>
#include <memory>
#include <vector>

#include "Decoders.hpp"
#include "Encoders.hpp"
18
19namespace armnn
20{
21
22void Split(const SplitterQueueDescriptor& data)
23{
24 const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
25
26 std::unique_ptr<Decoder<float>> decoderPtr =
27 MakeDecoder<float>(inputInfo, data.m_Inputs[0]->Map());
28 Decoder<float>& decoder = *decoderPtr;
29
30 for (unsigned int index = 0; index < inputInfo.GetNumElements(); ++index)
31 {
32 unsigned int indices[MaxNumOfTensorDimensions] = { 0 };
33
34 unsigned int indexRemainder = index;
35 unsigned int dimensionStride = inputInfo.GetNumElements();
36
37 for (unsigned int i = 0; i<inputInfo.GetNumDimensions(); i++)
38 {
39 dimensionStride /= inputInfo.GetShape()[i];
40 indices[i] = indexRemainder / dimensionStride; // Use integer division to round down.
41 indexRemainder -= indices[i] * dimensionStride;
42 }
43
44 for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx)
45 {
46 SplitterQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx];
47
48 //Split view extents are defined by the size of (the corresponding) input tensor.
49 const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[viewIdx]);
50 BOOST_ASSERT(outputInfo.GetNumDimensions() == inputInfo.GetNumDimensions());
51
52 // Check all dimensions to see if this element is inside the given input view.
53 bool insideView = true;
54 for (unsigned int i = 0; i<outputInfo.GetNumDimensions(); i++)
55 {
56 if (indices[i] < view.m_Origin[i])
57 {
58 insideView = false;
59 }
60 if (indices[i] >= view.m_Origin[i] + outputInfo.GetShape()[i])
61 {
62 insideView = false;
63 }
64 }
65
66 if (insideView)
67 {
68 std::unique_ptr<Encoder<float>> encoderPtr =
69 MakeEncoder<float>(outputInfo, data.m_Outputs[viewIdx]->Map());
70 Encoder<float>& encoder = *encoderPtr;
71
72 unsigned int outIndex = 0;
73 unsigned int dimensionStride = 1;
74 float inputValue = 0.f;
75
76 for (unsigned int i = outputInfo.GetNumDimensions(); i-- > 0;)
77 {
78 outIndex += dimensionStride * (indices[i] - view.m_Origin[i]);
79 dimensionStride *= outputInfo.GetShape()[i];
80 }
81
82 decoder += index;
83 inputValue = decoder.Get();
84 decoder -= index;
85
86 encoder += outIndex;
87 encoder.Set(inputValue);
88 break;
89 }
90 }
91 }
92}
93
94}