blob: fbcb2fdf5a55855af1f34b0b11703fa04bb75576 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
8#include "RefWorkloadUtils.hpp"
Matteo Martincigh8eb675e2018-10-17 14:43:29 +01009#include "TensorBufferArrayView.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010
11#include <armnn/Tensor.hpp>
12
13#include <cmath>
14
15namespace armnn
16{
17
18template<typename NormData>
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010019static void BatchNormImpl(NormData data,
telsoa014fcda012018-03-09 14:13:49 +000020 const float* varIn,
21 const float* meanIn,
22 const float* gammaIn,
23 const float* betaIn,
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010024 float* outputData,
25 const float* inputData)
telsoa014fcda012018-03-09 14:13:49 +000026{
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010027 const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
28 const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);
29
30 TensorBufferArrayView<const float> input(inputInfo.GetShape(),
31 inputData,
32 data.m_Parameters.m_DataLayout);
33 TensorBufferArrayView<float> output(outputInfo.GetShape(),
34 outputData,
35 data.m_Parameters.m_DataLayout);
36
37 DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);
38
39 for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++)
telsoa014fcda012018-03-09 14:13:49 +000040 {
41 float var = varIn[c];
42 float mean = meanIn[c];
43 float gamma = gammaIn[c];
44 float beta = betaIn[c];
45
46 float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
47 float add = beta - mult * mean;
48
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010049 for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000050 {
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010051 for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++)
telsoa014fcda012018-03-09 14:13:49 +000052 {
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010053 for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++)
telsoa014fcda012018-03-09 14:13:49 +000054 {
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010055 output.Get(n, c, h, w) = mult * input.Get(n, c, h, w) + add;
telsoa014fcda012018-03-09 14:13:49 +000056 }
57 }
58 }
59 }
60}
61
62} //namespace armnn