blob: 2700dd2c7bf5562a01de9418e80eb04390e3cd6d [file] [log] [blame]
surmeh013537c2c2018-05-18 16:31:43 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include "LayerWithParameters.hpp"
8
9namespace armnn
10{
11
12class PermuteLayer : public LayerWithParameters<PermuteDescriptor>
13{
14public:
15 virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
16 const IWorkloadFactory& factory) const override;
17
18 PermuteLayer* Clone(Graph& graph) const override;
19
20 void ValidateTensorShapesFromInputs() override;
telsoa01c577f2c2018-08-31 09:22:23 +010021 std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
surmeh013537c2c2018-05-18 16:31:43 +010022
23 const PermutationVector& GetPermutation() const
24 {
25 return m_Param.m_DimMappings;
26 }
27
28 bool IsInverse(const Layer& other) const
29 {
30 return (other.GetType() == LayerType::Permute) &&
31 GetPermutation().IsInverse(boost::polymorphic_downcast<const PermuteLayer*>(&other)->GetPermutation());
32 }
33
34 bool IsEqual(const Layer& other) const
35 {
36 return (other.GetType() == LayerType::Permute) &&
37 GetPermutation().IsEqual(boost::polymorphic_downcast<const PermuteLayer*>(&other)->GetPermutation());
38 }
39
40protected:
41 PermuteLayer(const PermuteDescriptor& param, const char* name);
42 ~PermuteLayer() = default;
43};
44
45} // namespace