blob: 963e3aa6f3d5fee4657348b5dc3ef7c938f291fb [file] [log] [blame]
Ruomei Yan25339c32019-05-28 16:48:20 +01001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
Ruomei Yan25339c32019-05-28 16:48:20 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "RefWorkloadUtils.hpp"
Colm Donelan0c479742021-12-10 12:43:54 +00007#include <armnn/backends/WorkloadData.hpp>
Ruomei Yan25339c32019-05-28 16:48:20 +01008#include <armnn/Tensor.hpp>
Ruomei Yan25339c32019-05-28 16:48:20 +01009#include "Splitter.hpp"
10
11#include <cmath>
12#include <limits>
13
14#include "Decoders.hpp"
15#include "Encoders.hpp"
16
17namespace armnn
18{
19
Finn Williamsb8181f72021-04-07 10:23:21 +010020void Split(const SplitterQueueDescriptor& data,
21 std::vector<ITensorHandle*> inputs,
22 std::vector<ITensorHandle*> outputs)
Ruomei Yan25339c32019-05-28 16:48:20 +010023{
Finn Williamsb8181f72021-04-07 10:23:21 +010024 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
Ruomei Yan25339c32019-05-28 16:48:20 +010025
26 std::unique_ptr<Decoder<float>> decoderPtr =
Finn Williamsb8181f72021-04-07 10:23:21 +010027 MakeDecoder<float>(inputInfo, inputs[0]->Map());
Ruomei Yan25339c32019-05-28 16:48:20 +010028 Decoder<float>& decoder = *decoderPtr;
29
30 for (unsigned int index = 0; index < inputInfo.GetNumElements(); ++index)
31 {
32 unsigned int indices[MaxNumOfTensorDimensions] = { 0 };
33
34 unsigned int indexRemainder = index;
35 unsigned int dimensionStride = inputInfo.GetNumElements();
36
37 for (unsigned int i = 0; i<inputInfo.GetNumDimensions(); i++)
38 {
39 dimensionStride /= inputInfo.GetShape()[i];
40 indices[i] = indexRemainder / dimensionStride; // Use integer division to round down.
41 indexRemainder -= indices[i] * dimensionStride;
42 }
43
44 for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx)
45 {
46 SplitterQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx];
47
48 //Split view extents are defined by the size of (the corresponding) input tensor.
Finn Williamsb8181f72021-04-07 10:23:21 +010049 const TensorInfo& outputInfo = GetTensorInfo(outputs[viewIdx]);
Colm Donelanb4ef1632024-02-01 15:00:43 +000050 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(
51 outputInfo.GetNumDimensions() == inputInfo.GetNumDimensions(),
52 "The number of output dimensions does not match the number of input dimensions.");
Ruomei Yan25339c32019-05-28 16:48:20 +010053
54 // Check all dimensions to see if this element is inside the given input view.
55 bool insideView = true;
56 for (unsigned int i = 0; i<outputInfo.GetNumDimensions(); i++)
57 {
58 if (indices[i] < view.m_Origin[i])
59 {
60 insideView = false;
61 }
62 if (indices[i] >= view.m_Origin[i] + outputInfo.GetShape()[i])
63 {
64 insideView = false;
65 }
66 }
67
68 if (insideView)
69 {
70 std::unique_ptr<Encoder<float>> encoderPtr =
Finn Williamsb8181f72021-04-07 10:23:21 +010071 MakeEncoder<float>(outputInfo, outputs[viewIdx]->Map());
Ruomei Yan25339c32019-05-28 16:48:20 +010072 Encoder<float>& encoder = *encoderPtr;
73
74 unsigned int outIndex = 0;
75 unsigned int dimensionStride = 1;
76 float inputValue = 0.f;
77
78 for (unsigned int i = outputInfo.GetNumDimensions(); i-- > 0;)
79 {
80 outIndex += dimensionStride * (indices[i] - view.m_Origin[i]);
81 dimensionStride *= outputInfo.GetShape()[i];
82 }
83
84 decoder += index;
85 inputValue = decoder.Get();
86 decoder -= index;
87
88 encoder += outIndex;
89 encoder.Set(inputValue);
90 break;
91 }
92 }
93 }
94}
95
96}