blob: 14a5f90d067a34bbd1656d6b07ead362e64812f6 [file] [log] [blame]
//
// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "L2NormalizationLayer.hpp"
6
7#include "LayerCloneBase.hpp"
8
9#include <armnn/TypesUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000010#include <armnn/backends/WorkloadData.hpp>
11#include <armnn/backends/WorkloadFactory.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010012
13namespace armnn
14{
15
Matteo Martincighbcd3c852018-09-28 14:14:12 +010016L2NormalizationLayer::L2NormalizationLayer(const L2NormalizationDescriptor& param, const char* name)
17 : LayerWithParameters(1, 1, LayerType::L2Normalization, param, name)
surmeh013537c2c2018-05-18 16:31:43 +010018{
19}
20
Derek Lamberti94a88d22019-12-10 21:12:59 +000021std::unique_ptr<IWorkload> L2NormalizationLayer::CreateWorkload(const IWorkloadFactory& factory) const
surmeh013537c2c2018-05-18 16:31:43 +010022{
23 L2NormalizationQueueDescriptor descriptor;
Keith Davisdf04d232020-10-23 17:20:05 +010024 SetAdditionalInfo(descriptor);
25
Teresa Charlin611c7fb2022-01-07 09:47:29 +000026 return factory.CreateWorkload(LayerType::L2Normalization, descriptor, PrepInfoAndDesc(descriptor));
surmeh013537c2c2018-05-18 16:31:43 +010027}
28
29L2NormalizationLayer* L2NormalizationLayer::Clone(Graph& graph) const
30{
Matteo Martincighbcd3c852018-09-28 14:14:12 +010031 return CloneBase<L2NormalizationLayer>(graph, m_Param, GetName());
surmeh013537c2c2018-05-18 16:31:43 +010032}
33
Finn Williamsf24effa2020-07-03 10:12:03 +010034void L2NormalizationLayer::ValidateTensorShapesFromInputs()
surmeh013537c2c2018-05-18 16:31:43 +010035{
telsoa01c577f2c2018-08-31 09:22:23 +010036 VerifyLayerConnections(1, CHECK_LOCATION());
surmeh013537c2c2018-05-18 16:31:43 +010037
Finn Williams87d0bda2020-07-03 10:12:03 +010038 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
39
Finn Williamsf24effa2020-07-03 10:12:03 +010040 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
Finn Williams87d0bda2020-07-03 10:12:03 +010041
Mike Kellya9ac6ba2023-06-30 15:18:26 +010042 auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetTensorInfo().GetShape() });
surmeh013537c2c2018-05-18 16:31:43 +010043
Declan-ARM7c75e332024-03-12 16:40:25 +000044 if (inferredShapes.size() != 1)
45 {
46 throw armnn::LayerValidationException("inferredShapes has "
47 + std::to_string(inferredShapes.size()) +
48 " elements - should only have 1.");
49 }
telsoa01c577f2c2018-08-31 09:22:23 +010050
Finn Williamsf24effa2020-07-03 10:12:03 +010051 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "L2NormalizationLayer");
surmeh013537c2c2018-05-18 16:31:43 +010052}
53
Nikhil Raj4d2eec02022-05-30 11:08:52 +010054void L2NormalizationLayer::ExecuteStrategy(IStrategy& strategy) const
jimfly01e9e7bfd2019-01-24 22:29:33 +000055{
Nikhil Raj4d2eec02022-05-30 11:08:52 +010056 strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
jimfly01e9e7bfd2019-01-24 22:29:33 +000057}
58
surmeh013537c2c2018-05-18 16:31:43 +010059} // namespace armnn