blob: 4a2945efbe900a100c08c46bfb6c5fef175b2f96 [file] [log] [blame]
surmeh013537c2c2018-05-18 16:31:43 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh013537c2c2018-05-18 16:31:43 +01004//
5#include "L2NormalizationLayer.hpp"
6
7#include "LayerCloneBase.hpp"
8
9#include <armnn/TypesUtils.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <backendsCommon/WorkloadData.hpp>
11#include <backendsCommon/WorkloadFactory.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010012
13namespace armnn
14{
15
Matteo Martincighbcd3c852018-09-28 14:14:12 +010016L2NormalizationLayer::L2NormalizationLayer(const L2NormalizationDescriptor& param, const char* name)
17 : LayerWithParameters(1, 1, LayerType::L2Normalization, param, name)
surmeh013537c2c2018-05-18 16:31:43 +010018{
19}
20
Derek Lamberti94a88d22019-12-10 21:12:59 +000021std::unique_ptr<IWorkload> L2NormalizationLayer::CreateWorkload(const IWorkloadFactory& factory) const
surmeh013537c2c2018-05-18 16:31:43 +010022{
23 L2NormalizationQueueDescriptor descriptor;
Derek Lamberti94a88d22019-12-10 21:12:59 +000024 return factory.CreateL2Normalization(descriptor, PrepInfoAndDesc(descriptor));
surmeh013537c2c2018-05-18 16:31:43 +010025}
26
27L2NormalizationLayer* L2NormalizationLayer::Clone(Graph& graph) const
28{
Matteo Martincighbcd3c852018-09-28 14:14:12 +010029 return CloneBase<L2NormalizationLayer>(graph, m_Param, GetName());
surmeh013537c2c2018-05-18 16:31:43 +010030}
31
Teresa Charlincdc01492020-06-09 18:00:20 +010032void L2NormalizationLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod)
surmeh013537c2c2018-05-18 16:31:43 +010033{
Teresa Charlincdc01492020-06-09 18:00:20 +010034 IgnoreUnused(shapeInferenceMethod);
35
telsoa01c577f2c2018-08-31 09:22:23 +010036 VerifyLayerConnections(1, CHECK_LOCATION());
surmeh013537c2c2018-05-18 16:31:43 +010037
telsoa01c577f2c2018-08-31 09:22:23 +010038 auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
surmeh013537c2c2018-05-18 16:31:43 +010039
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010040 ARMNN_ASSERT(inferredShapes.size() == 1);
telsoa01c577f2c2018-08-31 09:22:23 +010041
surmeh013537c2c2018-05-18 16:31:43 +010042 ConditionalThrowIfNotEqual<LayerValidationException>(
43 "L2NormalizationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
44 GetOutputSlot(0).GetTensorInfo().GetShape(),
telsoa01c577f2c2018-08-31 09:22:23 +010045 inferredShapes[0]);
surmeh013537c2c2018-05-18 16:31:43 +010046}
47
jimfly01e9e7bfd2019-01-24 22:29:33 +000048void L2NormalizationLayer::Accept(ILayerVisitor& visitor) const
49{
50 visitor.VisitL2NormalizationLayer(this, GetParameters(), GetName());
51}
52
surmeh013537c2c2018-05-18 16:31:43 +010053} // namespace armnn