blob: 24a0afacefa641f0c9494b48345880e219fccc2f [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7
7
#include "ClWorkloadFactoryHelper.hpp"

#include <armnnTestUtils/TensorHelpers.hpp>

#include <armnn/backends/TensorHandle.hpp>
#include <armnn/backends/WorkloadFactory.hpp>

#include <cl/ClContextControl.hpp>
#include <cl/ClWorkloadFactory.hpp>
#include <cl/OpenClTimer.hpp>

#include <armnnTestUtils/TensorCopyUtils.hpp>
#include <armnnTestUtils/WorkloadTestUtils.hpp>

#include <arm_compute/runtime/CL/CLScheduler.h>

#include <doctest/doctest.h>

#include <iostream>
#include <memory>
#include <string>
#include <vector>
27
telsoa01c577f2c2018-08-31 09:22:23 +010028using namespace armnn;
29
30struct OpenClFixture
31{
32 // Initialising ClContextControl to ensure OpenCL is loaded correctly for each test case.
33 // NOTE: Profiling needs to be enabled in ClContextControl to be able to obtain execution
34 // times from OpenClTimer.
Finn Williams40646322021-02-11 16:16:42 +000035 OpenClFixture() : m_ClContextControl(nullptr, nullptr, true) {}
telsoa01c577f2c2018-08-31 09:22:23 +010036 ~OpenClFixture() {}
37
38 ClContextControl m_ClContextControl;
39};
40
Sadik Armagan1625efc2021-06-10 18:24:34 +010041TEST_CASE_FIXTURE(OpenClFixture, "OpenClTimerBatchNorm")
telsoa01c577f2c2018-08-31 09:22:23 +010042{
Sadik Armagan1625efc2021-06-10 18:24:34 +010043//using FactoryType = ClWorkloadFactory;
44
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000045 auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager();
46 ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(memoryManager);
telsoa01c577f2c2018-08-31 09:22:23 +010047
48 const unsigned int width = 2;
49 const unsigned int height = 3;
50 const unsigned int channels = 2;
51 const unsigned int num = 1;
telsoa01c577f2c2018-08-31 09:22:23 +010052
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010053 TensorInfo inputTensorInfo( {num, channels, height, width}, DataType::Float32);
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000054 TensorInfo outputTensorInfo({num, channels, height, width}, DataType::Float32);
55 TensorInfo tensorInfo({channels}, DataType::Float32);
telsoa01c577f2c2018-08-31 09:22:23 +010056
Sadik Armagan483c8112021-06-01 09:24:52 +010057 std::vector<float> input =
58 {
59 1.f, 4.f,
60 4.f, 2.f,
61 1.f, 6.f,
telsoa01c577f2c2018-08-31 09:22:23 +010062
Sadik Armagan483c8112021-06-01 09:24:52 +010063 1.f, 1.f,
64 4.f, 1.f,
65 -2.f, 4.f
66 };
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010067
telsoa01c577f2c2018-08-31 09:22:23 +010068 // these values are per-channel of the input
Sadik Armagan483c8112021-06-01 09:24:52 +010069 std::vector<float> mean = { 3.f, -2.f };
70 std::vector<float> variance = { 4.f, 9.f };
71 std::vector<float> beta = { 3.f, 2.f };
72 std::vector<float> gamma = { 2.f, 1.f };
telsoa01c577f2c2018-08-31 09:22:23 +010073
Teresa Charline2a3b3f2020-08-17 23:22:11 +010074 ARMNN_NO_DEPRECATE_WARN_BEGIN
telsoa01c577f2c2018-08-31 09:22:23 +010075 std::unique_ptr<ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
76 std::unique_ptr<ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
Teresa Charline2a3b3f2020-08-17 23:22:11 +010077 ARMNN_NO_DEPRECATE_WARN_END
telsoa01c577f2c2018-08-31 09:22:23 +010078
79 BatchNormalizationQueueDescriptor data;
80 WorkloadInfo info;
James Conroy1f58f032021-04-27 17:13:27 +010081 ScopedTensorHandle meanTensor(tensorInfo);
82 ScopedTensorHandle varianceTensor(tensorInfo);
83 ScopedTensorHandle betaTensor(tensorInfo);
84 ScopedTensorHandle gammaTensor(tensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010085
Sadik Armagan483c8112021-06-01 09:24:52 +010086 AllocateAndCopyDataToITensorHandle(&meanTensor, mean.data());
87 AllocateAndCopyDataToITensorHandle(&varianceTensor, variance.data());
88 AllocateAndCopyDataToITensorHandle(&betaTensor, beta.data());
89 AllocateAndCopyDataToITensorHandle(&gammaTensor, gamma.data());
telsoa01c577f2c2018-08-31 09:22:23 +010090
91 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
92 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
93 data.m_Mean = &meanTensor;
94 data.m_Variance = &varianceTensor;
95 data.m_Beta = &betaTensor;
96 data.m_Gamma = &gammaTensor;
97 data.m_Parameters.m_Eps = 0.0f;
98
99 // for each channel:
100 // substract mean, divide by standard deviation (with an epsilon to avoid div by 0)
101 // multiply by gamma and add beta
Teresa Charlin611c7fb2022-01-07 09:47:29 +0000102 std::unique_ptr<IWorkload> workload = workloadFactory.CreateWorkload(LayerType::BatchNormalization, data, info);
telsoa01c577f2c2018-08-31 09:22:23 +0100103
104 inputHandle->Allocate();
105 outputHandle->Allocate();
106
Sadik Armagan483c8112021-06-01 09:24:52 +0100107 CopyDataToITensorHandle(inputHandle.get(), input.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100108
109 OpenClTimer openClTimer;
110
Sadik Armagan1625efc2021-06-10 18:24:34 +0100111 CHECK_EQ(openClTimer.GetName(), "OpenClKernelTimer");
telsoa01c577f2c2018-08-31 09:22:23 +0100112
113 //Start the timer
114 openClTimer.Start();
115
116 //Execute the workload
117 workload->Execute();
118
119 //Stop the timer
120 openClTimer.Stop();
121
Sadik Armagan1625efc2021-06-10 18:24:34 +0100122 CHECK_EQ(openClTimer.GetMeasurements().size(), 1);
telsoa01c577f2c2018-08-31 09:22:23 +0100123
Sadik Armagan1625efc2021-06-10 18:24:34 +0100124 CHECK_EQ(openClTimer.GetMeasurements().front().m_Name,
telsoa01c577f2c2018-08-31 09:22:23 +0100125 "OpenClKernelTimer/0: batchnormalization_layer_nchw GWS[1,3,2]");
126
Sadik Armagan1625efc2021-06-10 18:24:34 +0100127 CHECK(openClTimer.GetMeasurements().front().m_Value > 0);
telsoa01c577f2c2018-08-31 09:22:23 +0100128
129}
130
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000131#endif //aarch64 or x86_64