blob: 84d25bf7560397cdd58624f4a8627c36bc356bb3 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RankLayer.hpp"
7
8#include "LayerCloneBase.hpp"
9
Colm Donelan0c479742021-12-10 12:43:54 +000010#include <armnn/backends/WorkloadData.hpp>
11#include <armnn/backends/WorkloadFactory.hpp>
Finn Williams2605b232020-06-10 15:53:46 +010012
13namespace armnn
14{
15
16RankLayer::RankLayer(const char* name)
17 : Layer(1, 1, LayerType::Rank, name)
18{}
19
20std::unique_ptr<IWorkload> RankLayer::CreateWorkload(const IWorkloadFactory& factory) const
21{
22 RankQueueDescriptor descriptor;
Keith Davisdf04d232020-10-23 17:20:05 +010023 SetAdditionalInfo(descriptor);
24
Teresa Charlin611c7fb2022-01-07 09:47:29 +000025 return factory.CreateWorkload(LayerType::Rank, descriptor, PrepInfoAndDesc(descriptor));
Finn Williams2605b232020-06-10 15:53:46 +010026}
27
28Layer* RankLayer::Clone(Graph& graph) const
29{
30 RankLayer* clone = CloneBase<RankLayer>(graph, GetName());
31 return clone;
32}
33
Finn Williamsf24effa2020-07-03 10:12:03 +010034void RankLayer::ValidateTensorShapesFromInputs()
Finn Williams2605b232020-06-10 15:53:46 +010035{
Finn Williams2605b232020-06-10 15:53:46 +010036 VerifyLayerConnections(1, CHECK_LOCATION());
37
Finn Williams87d0bda2020-07-03 10:12:03 +010038 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
39 const TensorShape inferredShape = TensorShape(Dimensionality::Scalar);
Finn Williams2605b232020-06-10 15:53:46 +010040
Finn Williamsf24effa2020-07-03 10:12:03 +010041 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
42 ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "RankLayer");
Finn Williams87d0bda2020-07-03 10:12:03 +010043}
Jan Eilers1b2654f2021-09-24 15:45:46 +010044
45ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williams2605b232020-06-10 15:53:46 +010046void RankLayer::Accept(ILayerVisitor& visitor) const
47{
48 visitor.VisitRankLayer(this, GetName());
49}
Jan Eilers1b2654f2021-09-24 15:45:46 +010050ARMNN_NO_DEPRECATE_WARN_END
Finn Williams2605b232020-06-10 15:53:46 +010051
Finn Williamsb454c5c2021-02-09 15:56:23 +000052void RankLayer::ExecuteStrategy(IStrategy& strategy) const
53{
54 strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName());
55}
56
Finn Williams2605b232020-06-10 15:53:46 +010057} //namespace armnn