//
// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "CastLayer.hpp"

#include "LayerCloneBase.hpp"
#include <armnn/TypesUtils.hpp>

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>

#include <string>

13namespace armnn
14{
15
16CastLayer::CastLayer(const char* name)
17 : Layer(1, 1, LayerType::Cast, name)
18{
19}
20
21std::unique_ptr<IWorkload> CastLayer::CreateWorkload(const IWorkloadFactory& factory) const
22{
23 CastQueueDescriptor descriptor;
24 SetAdditionalInfo(descriptor);
25
Teresa Charlin611c7fb2022-01-07 09:47:29 +000026 return factory.CreateWorkload(LayerType::Cast, descriptor, PrepInfoAndDesc(descriptor));
mathad01b392e982021-04-07 12:07:30 +010027}
28
29CastLayer* CastLayer::Clone(Graph& graph) const
30{
31 return CloneBase<CastLayer>(graph, GetName());
32}
33
34void CastLayer::ValidateTensorShapesFromInputs()
35{
36 VerifyLayerConnections(1, CHECK_LOCATION());
37
38 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
39
40 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
41
Mike Kellya9ac6ba2023-06-30 15:18:26 +010042 auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetTensorInfo().GetShape() });
mathad01b392e982021-04-07 12:07:30 +010043
Declan-ARM7c75e332024-03-12 16:40:25 +000044 if (inferredShapes.size() != 1)
45 {
46 throw armnn::LayerValidationException("inferredShapes has "
47 + std::to_string(inferredShapes.size()) +
48 " elements - should only have 1.");
49 }
mathad01b392e982021-04-07 12:07:30 +010050
51 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "CastLayer");
52}
53
mathad01b392e982021-04-07 12:07:30 +010054} // namespace armnn