blob: d551221ae1a55310821c3381fdf1904b937c3b38 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include <armnn/ArmNN.hpp>
8#include <armnn/Tensor.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beckac42efd2018-09-26 17:41:13 +010010#include <test/TensorHelpers.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
David Beckac42efd2018-09-26 17:41:13 +010012#include <backends/CpuTensorHandle.hpp>
13#include <backends/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
David Beckac42efd2018-09-26 17:41:13 +010015#include <backends/test/QuantizeHelper.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
17
18template<typename T>
19LayerTestResult<T,4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
20 float qScale,
21 int32_t qOffset)
22{
23 const unsigned int width = 2;
24 const unsigned int height = 3;
25 const unsigned int channels = 2;
26 const unsigned int num = 1;
27
28 armnn::TensorInfo inputTensorInfo({num, channels, height, width}, armnn::GetDataType<T>());
29 armnn::TensorInfo outputTensorInfo({num, channels, height, width}, armnn::GetDataType<T>());
30 armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>());
31
32 // Set quantization parameters if the requested type is a quantized type.
33 if(armnn::IsQuantizedType<T>())
34 {
35 inputTensorInfo.SetQuantizationScale(qScale);
36 inputTensorInfo.SetQuantizationOffset(qOffset);
37 outputTensorInfo.SetQuantizationScale(qScale);
38 outputTensorInfo.SetQuantizationOffset(qOffset);
39 tensorInfo.SetQuantizationScale(qScale);
40 tensorInfo.SetQuantizationOffset(qOffset);
41 }
42
43 auto input = MakeTensor<T, 4>(inputTensorInfo,
44 QuantizedVector<T>(qScale, qOffset,
45 {
46 1.f, 4.f,
47 4.f, 2.f,
48 1.f, 6.f,
49
50 1.f, 1.f,
51 4.f, 1.f,
52 -2.f, 4.f
53 }));
telsoa01c577f2c2018-08-31 09:22:23 +010054 // These values are per-channel of the input.
telsoa014fcda012018-03-09 14:13:49 +000055 auto mean = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, -2}));
56 auto variance = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {4, 9}));
57 auto beta = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, 2}));
58 auto gamma = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {2, 1}));
59 LayerTestResult<T,4> ret(outputTensorInfo);
60
61 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
62 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
63
64 armnn::BatchNormalizationQueueDescriptor data;
65 armnn::WorkloadInfo info;
66 armnn::ScopedCpuTensorHandle meanTensor(tensorInfo);
67 armnn::ScopedCpuTensorHandle varianceTensor(tensorInfo);
68 armnn::ScopedCpuTensorHandle betaTensor(tensorInfo);
69 armnn::ScopedCpuTensorHandle gammaTensor(tensorInfo);
70
71 AllocateAndCopyDataToITensorHandle(&meanTensor, &mean[0]);
72 AllocateAndCopyDataToITensorHandle(&varianceTensor, &variance[0]);
73 AllocateAndCopyDataToITensorHandle(&betaTensor, &beta[0]);
74 AllocateAndCopyDataToITensorHandle(&gammaTensor, &gamma[0]);
75
76 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
77 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
78 data.m_Mean = &meanTensor;
79 data.m_Variance = &varianceTensor;
80 data.m_Beta = &betaTensor;
81 data.m_Gamma = &gammaTensor;
82 data.m_Parameters.m_Eps = 0.0f;
83
telsoa01c577f2c2018-08-31 09:22:23 +010084 // For each channel:
85 // substract mean, divide by standard deviation (with an epsilon to avoid div by 0),
telsoa014fcda012018-03-09 14:13:49 +000086 // multiply by gamma and add beta
87 ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo,
88 QuantizedVector<T>(qScale, qOffset,
89 {
90 1.f, 4.f,
91 4.f, 2.f,
92 1.f, 6.f,
93
94 3.f, 3.f,
95 4.f, 3.f,
96 2.f, 4.f
97 }));
98
99 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateBatchNormalization(data, info);
100
101 inputHandle->Allocate();
102 outputHandle->Allocate();
103
104 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
105
106 workload->Execute();
107
108 CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
109
110 return ret;
111}