blob: 82e6e8674792779f7f4740652c398a583a8b61ed [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include <armnn/ArmNN.hpp>
8#include <armnn/Tensor.hpp>
9#include <backends/WorkloadInfo.hpp>
10
11#include "test/TensorHelpers.hpp"
12
13#include "backends/CpuTensorHandle.hpp"
14#include "backends/WorkloadFactory.hpp"
15
16#include "backends/test/QuantizeHelper.hpp"
17
18
19template<typename T>
20LayerTestResult<T,4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
21 float qScale,
22 int32_t qOffset)
23{
24 const unsigned int width = 2;
25 const unsigned int height = 3;
26 const unsigned int channels = 2;
27 const unsigned int num = 1;
28
29 armnn::TensorInfo inputTensorInfo({num, channels, height, width}, armnn::GetDataType<T>());
30 armnn::TensorInfo outputTensorInfo({num, channels, height, width}, armnn::GetDataType<T>());
31 armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>());
32
33 // Set quantization parameters if the requested type is a quantized type.
34 if(armnn::IsQuantizedType<T>())
35 {
36 inputTensorInfo.SetQuantizationScale(qScale);
37 inputTensorInfo.SetQuantizationOffset(qOffset);
38 outputTensorInfo.SetQuantizationScale(qScale);
39 outputTensorInfo.SetQuantizationOffset(qOffset);
40 tensorInfo.SetQuantizationScale(qScale);
41 tensorInfo.SetQuantizationOffset(qOffset);
42 }
43
44 auto input = MakeTensor<T, 4>(inputTensorInfo,
45 QuantizedVector<T>(qScale, qOffset,
46 {
47 1.f, 4.f,
48 4.f, 2.f,
49 1.f, 6.f,
50
51 1.f, 1.f,
52 4.f, 1.f,
53 -2.f, 4.f
54 }));
telsoa01c577f2c2018-08-31 09:22:23 +010055 // These values are per-channel of the input.
telsoa014fcda012018-03-09 14:13:49 +000056 auto mean = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, -2}));
57 auto variance = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {4, 9}));
58 auto beta = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, 2}));
59 auto gamma = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {2, 1}));
60 LayerTestResult<T,4> ret(outputTensorInfo);
61
62 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
63 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
64
65 armnn::BatchNormalizationQueueDescriptor data;
66 armnn::WorkloadInfo info;
67 armnn::ScopedCpuTensorHandle meanTensor(tensorInfo);
68 armnn::ScopedCpuTensorHandle varianceTensor(tensorInfo);
69 armnn::ScopedCpuTensorHandle betaTensor(tensorInfo);
70 armnn::ScopedCpuTensorHandle gammaTensor(tensorInfo);
71
72 AllocateAndCopyDataToITensorHandle(&meanTensor, &mean[0]);
73 AllocateAndCopyDataToITensorHandle(&varianceTensor, &variance[0]);
74 AllocateAndCopyDataToITensorHandle(&betaTensor, &beta[0]);
75 AllocateAndCopyDataToITensorHandle(&gammaTensor, &gamma[0]);
76
77 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
78 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
79 data.m_Mean = &meanTensor;
80 data.m_Variance = &varianceTensor;
81 data.m_Beta = &betaTensor;
82 data.m_Gamma = &gammaTensor;
83 data.m_Parameters.m_Eps = 0.0f;
84
telsoa01c577f2c2018-08-31 09:22:23 +010085 // For each channel:
86 // substract mean, divide by standard deviation (with an epsilon to avoid div by 0),
telsoa014fcda012018-03-09 14:13:49 +000087 // multiply by gamma and add beta
88 ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo,
89 QuantizedVector<T>(qScale, qOffset,
90 {
91 1.f, 4.f,
92 4.f, 2.f,
93 1.f, 6.f,
94
95 3.f, 3.f,
96 4.f, 3.f,
97 2.f, 4.f
98 }));
99
100 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateBatchNormalization(data, info);
101
102 inputHandle->Allocate();
103 outputHandle->Allocate();
104
105 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
106
107 workload->Execute();
108
109 CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
110
111 return ret;
112}