blob: 16dd9a37441b325d02e43f01ca7484de5ed7601b [file] [log] [blame]
mathad01b392e982021-04-07 12:07:30 +01001//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "CastLayer.hpp"
6
7#include "LayerCloneBase.hpp"
8#include <armnn/TypesUtils.hpp>
9
10#include <backendsCommon/WorkloadData.hpp>
11#include <backendsCommon/WorkloadFactory.hpp>
12
13namespace armnn
14{
15
16CastLayer::CastLayer(const char* name)
17 : Layer(1, 1, LayerType::Cast, name)
18{
19}
20
21std::unique_ptr<IWorkload> CastLayer::CreateWorkload(const IWorkloadFactory& factory) const
22{
23 CastQueueDescriptor descriptor;
24 SetAdditionalInfo(descriptor);
25
26 return factory.CreateCast(descriptor, PrepInfoAndDesc(descriptor));
27}
28
29CastLayer* CastLayer::Clone(Graph& graph) const
30{
31 return CloneBase<CastLayer>(graph, GetName());
32}
33
34void CastLayer::ValidateTensorShapesFromInputs()
35{
36 VerifyLayerConnections(1, CHECK_LOCATION());
37
38 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
39
40 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
41
42 auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
43
44 ARMNN_ASSERT(inferredShapes.size() == 1);
45
46 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "CastLayer");
47}
48
49void CastLayer::Accept(ILayerVisitor& visitor) const
50{
51 IgnoreUnused(visitor);
52 throw armnn::Exception("CastLayer VisitCastLayer is not implemented");
53}
54
55} // namespace armnn