blob: d6c30817b849064948817b10e3483468fc6c54c7 [file] [log] [blame]
Sadik Armagan0d4863d2019-10-09 14:26:32 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "NeonInstanceNormalizationWorkload.hpp"
7
8#include "NeonWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <backendsCommon/CpuTensorHandle.hpp>
12#include <neon/NeonTensorHandle.hpp>
13
14using namespace armnn::armcomputetensorutils;
15
16namespace armnn
17{
18
19arm_compute::Status NeonInstanceNormalizationWorkloadValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const InstanceNormalizationDescriptor& descriptor)
22{
23 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
24 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
25
26 return arm_compute::NEInstanceNormalizationLayer::validate(&aclInputInfo,
27 &aclOutputInfo,
28 descriptor.m_Gamma,
29 descriptor.m_Beta,
30 descriptor.m_Eps);
31}
32
33NeonInstanceNormalizationWorkload::NeonInstanceNormalizationWorkload(
34 const InstanceNormalizationQueueDescriptor& descriptor,
35 const WorkloadInfo& info)
36 : BaseWorkload<InstanceNormalizationQueueDescriptor>(descriptor, info)
37{
38 m_Data.ValidateInputsOutputs("NeonInstanceNormalizationWorkload", 1, 1);
39
40 arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
41 arm_compute::ITensor& output = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
42
43 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
44 input.info()->set_data_layout(aclDataLayout);
45 output.info()->set_data_layout(aclDataLayout);
46
47 m_Layer.configure(&input,
48 &output,
49 descriptor.m_Parameters.m_Gamma,
50 descriptor.m_Parameters.m_Beta,
51 descriptor.m_Parameters.m_Eps);
52};
53
54void NeonInstanceNormalizationWorkload::Execute() const
55{
56 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonInstanceNormalizationWorkload_Execute");
57 m_Layer.run();
58}
59
60} // namespace armnn