blob: a7579c8373139d9a6b8e40f1461a779cd5121a77 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
8#include "RefWorkloadUtils.hpp"
9
10#include <armnn/Tensor.hpp>
11
12#include <cmath>
13
14namespace armnn
15{
16
17template<typename NormData>
18static void BatchNormImpl(NormData data,
19 const float* varIn,
20 const float* meanIn,
21 const float* gammaIn,
22 const float* betaIn,
23 float * outputData,
24 const float * inputData)
25{
26 const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
27 for (unsigned int c = 0; c < inputInfo0.GetShape()[1]; c++)
28 {
29 float var = varIn[c];
30 float mean = meanIn[c];
31 float gamma = gammaIn[c];
32 float beta = betaIn[c];
33
34 float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
35 float add = beta - mult * mean;
36
37 for (unsigned int n = 0; n < inputInfo0.GetShape()[0]; n++)
38 {
39 for (unsigned int j = 0; j < inputInfo0.GetShape()[2]; j++)
40 {
41 for (unsigned int i = 0; i < inputInfo0.GetShape()[3]; i++)
42 {
43 unsigned int index = i +
44 j*inputInfo0.GetShape()[3] +
45 c*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2] +
46 n*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2]
47 * inputInfo0.GetShape()[1];
48
49 outputData[index] = mult * inputData[index] + add;
50 }
51 }
52 }
53 }
54}
55
56} //namespace armnn