blob: 61e68633042fd314f7dae3f44f2ebf6bd8b9a209 [file] [log] [blame]
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001//
Finn Williams87d0bda2020-07-03 10:12:03 +01002// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "TransposeLayer.hpp"
7
8#include "LayerCloneBase.hpp"
9
10#include <armnn/TypesUtils.hpp>
11
12#include <armnnUtils/Transpose.hpp>
13
14#include <backendsCommon/WorkloadData.hpp>
15#include <backendsCommon/WorkloadFactory.hpp>
16
17namespace armnn
18{
19
20TransposeLayer::TransposeLayer(const TransposeDescriptor& param, const char* name)
21 : LayerWithParameters(1, 1, LayerType::Transpose, param, name)
22{
23}
24
25std::unique_ptr<IWorkload> TransposeLayer::CreateWorkload(const IWorkloadFactory& factory) const
26{
27 TransposeQueueDescriptor descriptor;
28 return factory.CreateTranspose(descriptor, PrepInfoAndDesc(descriptor));
29}
30
31TransposeLayer* TransposeLayer::Clone(Graph& graph) const
32{
33 return CloneBase<TransposeLayer>(graph, m_Param, GetName());
34}
35
36std::vector<TensorShape> TransposeLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
37{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010038 ARMNN_ASSERT(inputShapes.size() == 1);
Mike Kellyc9ea45a2020-02-28 18:11:58 +000039 const TensorShape& inShape = inputShapes[0];
40 return std::vector<TensorShape> ({armnnUtils::TransposeTensorShape(inShape, m_Param.m_DimMappings)});
41}
42
Finn Williamsf24effa2020-07-03 10:12:03 +010043void TransposeLayer::ValidateTensorShapesFromInputs()
Mike Kellyc9ea45a2020-02-28 18:11:58 +000044{
45 VerifyLayerConnections(1, CHECK_LOCATION());
46
Finn Williams87d0bda2020-07-03 10:12:03 +010047 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
48
Finn Williamsf24effa2020-07-03 10:12:03 +010049 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
Finn Williams87d0bda2020-07-03 10:12:03 +010050
Mike Kellyc9ea45a2020-02-28 18:11:58 +000051 auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
52
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010053 ARMNN_ASSERT(inferredShapes.size() == 1);
Mike Kellyc9ea45a2020-02-28 18:11:58 +000054
Finn Williamsf24effa2020-07-03 10:12:03 +010055 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "TransposeLayer");
Mike Kellyc9ea45a2020-02-28 18:11:58 +000056}
57
58void TransposeLayer::Accept(ILayerVisitor& visitor) const
59{
60 visitor.VisitTransposeLayer(this, GetParameters(), GetName());
61}
62
63} // namespace armnn