blob: adf452dde19dbb6da5e4609acf92363cfb53231b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01006#include "RefNormalizationWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007#include "RefWorkloadUtils.hpp"
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01008#include "Decoders.hpp"
9#include "Encoders.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010
11#include <armnn/Tensor.hpp>
12
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <armnnUtils/DataLayoutIndexed.hpp>
14
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010015#include <Profiling.hpp>
16
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/numeric/conversion/cast.hpp>
18
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010019using namespace armnn;
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010022namespace
telsoa014fcda012018-03-09 14:13:49 +000023{
24
telsoa01c577f2c2018-08-31 09:22:23 +010025// Helper function to compute "Within" normalization using Krichevsky 2012: Local Brightness Normalization.
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010026void NormalizeWithinUingLbr(Decoder<float>& inputData,
27 Encoder<float>& outputData,
28 const TensorShape& tensorShape,
29 uint32_t norm_size,
30 float alpha,
31 float beta,
32 float kappa)
telsoa014fcda012018-03-09 14:13:49 +000033{
34 const unsigned int batchSize = tensorShape[0];
35 const unsigned int depth = tensorShape[1];
36 const unsigned int rows = tensorShape[2];
37 const unsigned int cols = tensorShape[3];
38
39 int radius = boost::numeric_cast<int>(norm_size / 2u); /* Strong Assumption on rounding Mode */
40
41 for (unsigned int n = 0; n < batchSize; n++)
42 {
43 for (unsigned int c = 0; c < depth; c++)
44 {
45 for (unsigned int h = 0; h < rows; h++)
46 {
47 for (unsigned int w = 0; w < cols; w++)
48 {
49 float accumulated_scale = 0.0;
50 for (int y = -radius; y <= radius; y++)
51 {
52 for (int x = -radius; x <= radius; x++)
53 {
54 int i = boost::numeric_cast<int>(w) + x;
55 int j = boost::numeric_cast<int>(h) + y;
56
57 if ((i < 0) || (i >= boost::numeric_cast<int>(cols)))
58 {
59 continue;
60 }
61
62 if ((j < 0) || (j >= boost::numeric_cast<int>(rows)))
63 {
64 continue;
65 }
66
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010067 unsigned int inputIndex = n * cols * rows * depth +
68 c * cols * rows +
69 boost::numeric_cast<unsigned int>(j) * cols +
70 boost::numeric_cast<unsigned int>(i);
71 inputData[inputIndex];
72 float inval = inputData.Get();
telsoa014fcda012018-03-09 14:13:49 +000073
74 accumulated_scale += inval*inval;
75 }
76 }
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010077
78 unsigned int index = n * cols * rows * depth +
79 c * cols * rows +
80 h * cols +
81 w;
82 inputData[index];
83 outputData[index];
84 outputData.Set(inputData.Get() / (powf((kappa + (accumulated_scale * alpha)), beta)));
telsoa014fcda012018-03-09 14:13:49 +000085 }
86 }
87 }
88 }
89}
90
telsoa01c577f2c2018-08-31 09:22:23 +010091// Helper function to compute "Across" normalization using Krichevsky 2012: Local Brightness Normalization.
Matteo Martincigh2fc70c52019-06-05 14:12:48 +010092void NormalizeAcrossUingLbr(Decoder<float>& inputData,
93 Encoder<float>& outputData,
telsoa014fcda012018-03-09 14:13:49 +000094 const TensorShape& tensorShape,
95 uint32_t norm_size,
96 float alpha,
97 float beta,
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +010098 float kappa,
99 DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000100{
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +0100101 DataLayoutIndexed dataLayoutIndexed(dataLayout);
102
telsoa014fcda012018-03-09 14:13:49 +0000103 const unsigned int batchSize = tensorShape[0];
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +0100104 const unsigned int depth = tensorShape[dataLayoutIndexed.GetChannelsIndex()];
105 const unsigned int rows = tensorShape[dataLayoutIndexed.GetHeightIndex()];
106 const unsigned int cols = tensorShape[dataLayoutIndexed.GetWidthIndex()];
telsoa014fcda012018-03-09 14:13:49 +0000107
108 int radius = boost::numeric_cast<int>(norm_size / 2u); /* Strong Assumption on rounding Mode */
109
110 for (unsigned int n = 0; n < batchSize; n++)
111 {
112 for (unsigned int c = 0; c < depth; c++)
113 {
114 for (unsigned int h = 0; h < rows; h++)
115 {
116 for (unsigned int w = 0; w < cols; w++)
117 {
118 float accumulated_scale = 0.0;
119 for (int z = -radius; z <= radius; z++)
120 {
121 int k = boost::numeric_cast<int>(c) + z;
122
123 if ((k < 0) || (k >= boost::numeric_cast<int>(depth)))
124 {
125 continue;
126 }
127
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100128 unsigned inputIndex = dataLayoutIndexed.GetIndex(tensorShape,
129 n,
130 boost::numeric_cast<unsigned int>(k),
131 h,
132 w);
133
134 inputData[inputIndex];
135 float inval = inputData.Get();
telsoa014fcda012018-03-09 14:13:49 +0000136
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +0100137 accumulated_scale += inval * inval;
telsoa014fcda012018-03-09 14:13:49 +0000138 }
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +0100139
telsoa014fcda012018-03-09 14:13:49 +0000140 float scale = kappa + (accumulated_scale * alpha);
141 scale = powf(scale, -beta);
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +0100142
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100143 unsigned index = dataLayoutIndexed.GetIndex(tensorShape, n, c, h, w);
144
145 inputData[index];
146 outputData[index];
147 outputData.Set(scale * inputData.Get());
telsoa014fcda012018-03-09 14:13:49 +0000148 }
149 }
150 }
151 }
152}
153
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100154} // Anonymous namespace
155
156namespace armnn
telsoa014fcda012018-03-09 14:13:49 +0000157{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100158
159RefNormalizationWorkload::RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor,
160 const WorkloadInfo& info)
161 : BaseWorkload(descriptor, info)
162{}
163
164void RefNormalizationWorkload::Execute() const
165{
166 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefNormalizationWorkload_Execute");
telsoa014fcda012018-03-09 14:13:49 +0000167
168 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
169
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100170 auto inputDecoder = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
171 auto outputEncoder = MakeEncoder<float>(inputInfo, m_Data.m_Outputs[0]->Map());
telsoa014fcda012018-03-09 14:13:49 +0000172
telsoa014fcda012018-03-09 14:13:49 +0000173 if (NormalizationAlgorithmMethod::LocalBrightness == m_Data.m_Parameters.m_NormMethodType)
174 {
175 if (NormalizationAlgorithmChannel::Within == m_Data.m_Parameters.m_NormChannelType)
176 {
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100177 NormalizeWithinUingLbr(*inputDecoder,
178 *outputEncoder,
telsoa014fcda012018-03-09 14:13:49 +0000179 inputInfo.GetShape(),
180 m_Data.m_Parameters.m_NormSize,
181 m_Data.m_Parameters.m_Alpha,
182 m_Data.m_Parameters.m_Beta,
183 m_Data.m_Parameters.m_K);
184 }
185 else if (NormalizationAlgorithmChannel::Across == m_Data.m_Parameters.m_NormChannelType)
186 {
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100187 NormalizeAcrossUingLbr(*inputDecoder,
188 *outputEncoder,
telsoa014fcda012018-03-09 14:13:49 +0000189 inputInfo.GetShape(),
190 m_Data.m_Parameters.m_NormSize,
191 m_Data.m_Parameters.m_Alpha,
192 m_Data.m_Parameters.m_Beta,
Matteo Martincigh8e6f92d2018-10-18 08:45:39 +0100193 m_Data.m_Parameters.m_K,
194 m_Data.m_Parameters.m_DataLayout);
telsoa014fcda012018-03-09 14:13:49 +0000195 }
196 else
197 {
Derek Lamberti08446972019-11-26 16:38:31 +0000198 ARMNN_LOG(warning) << "Illegal NORMALIZATION mode in normalization_f32";
telsoa014fcda012018-03-09 14:13:49 +0000199 return;
200 }
201 }
202 else
203 {
Derek Lamberti08446972019-11-26 16:38:31 +0000204 ARMNN_LOG(warning) << "Lcr method (Jarret 2009: Local Contrast Normalization) not supported yet.";
telsoa014fcda012018-03-09 14:13:49 +0000205 return;
206 }
207}
208
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100209} // namespace armnn