blob: 799e7a327b82458178cea74f85fa368e45ff8de5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
8#include "RefWorkloadUtils.hpp"
Matteo Martincigh8eb675e2018-10-17 14:43:29 +01009#include "TensorBufferArrayView.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010
11#include <armnn/Tensor.hpp>
12
Matteo Martincigh21350152018-11-28 16:22:22 +000013#include <DataLayoutIndexed.hpp>
14
telsoa014fcda012018-03-09 14:13:49 +000015#include <cmath>
16
17namespace armnn
18{
19
20template<typename NormData>
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010021static void BatchNormImpl(NormData data,
telsoa014fcda012018-03-09 14:13:49 +000022 const float* varIn,
23 const float* meanIn,
24 const float* gammaIn,
25 const float* betaIn,
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010026 float* outputData,
27 const float* inputData)
telsoa014fcda012018-03-09 14:13:49 +000028{
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010029 const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
30 const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);
31
32 TensorBufferArrayView<const float> input(inputInfo.GetShape(),
33 inputData,
34 data.m_Parameters.m_DataLayout);
35 TensorBufferArrayView<float> output(outputInfo.GetShape(),
36 outputData,
37 data.m_Parameters.m_DataLayout);
38
Matteo Martincigh21350152018-11-28 16:22:22 +000039 armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010040
41 for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++)
telsoa014fcda012018-03-09 14:13:49 +000042 {
43 float var = varIn[c];
44 float mean = meanIn[c];
45 float gamma = gammaIn[c];
46 float beta = betaIn[c];
47
48 float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
49 float add = beta - mult * mean;
50
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010051 for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000052 {
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010053 for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++)
telsoa014fcda012018-03-09 14:13:49 +000054 {
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010055 for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++)
telsoa014fcda012018-03-09 14:13:49 +000056 {
Matteo Martincigh8eb675e2018-10-17 14:43:29 +010057 output.Get(n, c, h, w) = mult * input.Get(n, c, h, w) + add;
telsoa014fcda012018-03-09 14:13:49 +000058 }
59 }
60 }
61 }
62}
63
64} //namespace armnn