blob: 8877ee228465f8d870240b43dcda9d4b1bb720f6 [file] [log] [blame]
Ferran Balaguerb2845652019-02-27 09:42:06 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Merger.hpp"
7#include "RefWorkloadUtils.hpp"
8
9namespace armnn
10{
11
12template <>
13void CopyValue<float>(const float& source, const TensorInfo& sourceInfo, float& dest, const TensorInfo& destInfo)
14{
15 dest = source;
16}
17
18template <>
19void CopyValue<uint8_t>(const uint8_t& source, const TensorInfo& sourceInfo, uint8_t& dest, const TensorInfo& destInfo)
20{
21 if (sourceInfo.GetQuantizationScale() != destInfo.GetQuantizationScale() ||
22 sourceInfo.GetQuantizationOffset() != destInfo.GetQuantizationOffset())
23 {
Jim Flynn18ce3382019-03-08 11:08:30 +000024 // Dequantize value according to sourceInfo params
Ferran Balaguerb2845652019-02-27 09:42:06 +000025 float dequantizedValue = armnn::Dequantize<uint8_t>(source,
26 sourceInfo.GetQuantizationScale(),
27 sourceInfo.GetQuantizationOffset());
28
29 // Quantize again according to destInfo paramns
30 dest = armnn::Quantize<uint8_t>(dequantizedValue,
31 destInfo.GetQuantizationScale(),
32 destInfo.GetQuantizationOffset());
33 }
34 else
35 {
36 dest = source;
37 }
38}
39
40template <typename DataType>
41void Merger(const MergerQueueDescriptor& data)
42{
43 const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]);
44
45 for (unsigned int index = 0 ; index < outputInfo0.GetNumElements(); ++index)
46 {
47 unsigned int indices[MaxNumOfTensorDimensions] = { 0 };
48
49 unsigned int indexRemainder = index;
50 unsigned int dimensionStride = outputInfo0.GetNumElements();
51
52 for (unsigned int i = 0; i < outputInfo0.GetNumDimensions(); i++)
53 {
54 dimensionStride /= outputInfo0.GetShape()[i];
55 indices[i] = indexRemainder / dimensionStride; // Use integer division to round down.
56 indexRemainder -= indices[i] * dimensionStride;
57 }
58
59 for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx)
60 {
61 MergerQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx];
62
63 //Split view extents are defined by the size of (the corresponding) input tensor.
64 const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[viewIdx]);
65 BOOST_ASSERT(inputInfo.GetNumDimensions() == outputInfo0.GetNumDimensions());
66
67 // Check all dimensions to see if this element is inside the given input view.
68 bool insideView = true;
69 for (unsigned int i = 0; i < inputInfo.GetNumDimensions(); i++)
70 {
71 if (indices[i] < view.m_Origin[i])
72 {
73 insideView = false;
74 }
75 if (indices[i] >= view.m_Origin[i] + inputInfo.GetShape()[i])
76 {
77 insideView = false;
78 }
79 }
80
81 if (insideView)
82 {
83 unsigned int inIndex = 0;
84 unsigned int dimensionStride = 1;
85
86 for (unsigned int i = inputInfo.GetNumDimensions(); i-- > 0;)
87 {
88 inIndex += dimensionStride * (indices[i] - view.m_Origin[i]);
89 dimensionStride *= inputInfo.GetShape()[i];
90 }
91
92 CopyValue<DataType>((GetInputTensorData<DataType>(viewIdx, data))[inIndex],
93 GetTensorInfo(data.m_Inputs[viewIdx]),
94 (GetOutputTensorData<DataType>(0, data))[index],
95 outputInfo0);
96
97 //What should we do if input views overlap on the output tensor?
98 //We could error, take the average, or shm else...
99 //For now just stop after finding first view (input) that matches.
100 break;
101 }
102 }
103 }
104}
105
106template void Merger<float>(const MergerQueueDescriptor& data);
107
108template void Merger<uint8_t>(const MergerQueueDescriptor& data);
109
110} //namespace armnn