blob: b842f1b4d5a13b01c91f00828a7432c940e6c5bf [file] [log] [blame]
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001//
Finn Williams87d0bda2020-07-03 10:12:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01003// SPDX-License-Identifier: MIT
4//
5#include "StackLayer.hpp"
6#include "LayerCloneBase.hpp"
7
8#include <armnn/TypesUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +00009#include <armnn/backends/WorkloadData.hpp>
10#include <armnn/backends/WorkloadFactory.hpp>
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010011
12#include <queue>
13
14namespace armnn
15{
16
17StackLayer::StackLayer(const StackDescriptor& param, const char* name)
18 : LayerWithParameters(param.m_NumInputs, 1, LayerType::Stack, param, name)
19{
20}
21
Derek Lamberti94a88d22019-12-10 21:12:59 +000022std::unique_ptr<IWorkload> StackLayer::CreateWorkload(const IWorkloadFactory& factory) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010023{
24 StackQueueDescriptor descriptor;
Keith Davisdf04d232020-10-23 17:20:05 +010025 SetAdditionalInfo(descriptor);
26
Teresa Charlin611c7fb2022-01-07 09:47:29 +000027 return factory.CreateWorkload(LayerType::Stack, descriptor, PrepInfoAndDesc(descriptor));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010028}
29
30StackLayer* StackLayer::Clone(Graph& graph) const
31{
32 return CloneBase<StackLayer>(graph, m_Param, GetName());
33}
34
35std::vector<TensorShape> StackLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
36{
Jan Eilers8eb25602020-03-09 12:13:48 +000037 IgnoreUnused(inputShapes);
Derek Lamberti94a88d22019-12-10 21:12:59 +000038
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010039 const TensorShape& inputShape = m_Param.m_InputShape;
40 const unsigned int inputNumDimensions = inputShape.GetNumDimensions();
41 const unsigned int axis = m_Param.m_Axis;
42
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT(axis <= inputNumDimensions);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010044
Rob Hughes91e1d892019-08-23 10:11:58 +010045 std::vector<unsigned int> dimensionSizes(inputNumDimensions + 1, 0);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010046 for (unsigned int i = 0; i < axis; ++i)
47 {
48 dimensionSizes[i] = inputShape[i];
49 }
50
51 dimensionSizes[axis] = m_Param.m_NumInputs;
52
53 for (unsigned int i = axis + 1; i < inputNumDimensions + 1; ++i)
54 {
55 dimensionSizes[i] = inputShape[i-1];
56 }
57
Rob Hughes91e1d892019-08-23 10:11:58 +010058 TensorShape targetShape = TensorShape(inputNumDimensions + 1, dimensionSizes.data());
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010059
60 return std::vector<TensorShape>({ targetShape });
61}
62
Finn Williamsf24effa2020-07-03 10:12:03 +010063void StackLayer::ValidateTensorShapesFromInputs()
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010064{
65 // Validates Stack layer.
66 ConditionalThrowIfNotEqual<LayerValidationException>(
67 "StackLayer: Num Input Slots must match Num Inputs.",
68 m_Param.m_NumInputs,
69 GetNumInputSlots());
70
71 VerifyLayerConnections(m_Param.m_NumInputs, CHECK_LOCATION());
72
Finn Williams87d0bda2020-07-03 10:12:03 +010073 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
74
Finn Williamsf24effa2020-07-03 10:12:03 +010075 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
Finn Williams87d0bda2020-07-03 10:12:03 +010076
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010077 // Constructs and validates input shapes
78 std::vector<TensorShape> inputShapes;
79 for (unsigned int i = 0; i < GetNumInputSlots(); ++i)
80 {
81 TensorShape inputShape = GetInputSlot(i).GetConnection()->GetTensorInfo().GetShape();
82 if (inputShape != m_Param.m_InputShape)
83 {
Matthew Jackson82b15ed2019-07-25 16:14:30 +010084 throw LayerValidationException("StackLayer: TensorShape set on InputSlot[" +
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010085 std::to_string(i) +
86 "] does not match defined input shape");
87 }
88 inputShapes.push_back(inputShape);
89 }
90
91 auto inferredShapes = InferOutputShapes(inputShapes);
92
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010093 ARMNN_ASSERT(inferredShapes.size() == 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010094
Finn Williamsf24effa2020-07-03 10:12:03 +010095 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "StackLayer");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010096}
97
Jan Eilers1b2654f2021-09-24 15:45:46 +010098ARMNN_NO_DEPRECATE_WARN_BEGIN
Matthew Jackson2b8c1da2019-07-04 14:59:16 +010099void StackLayer::Accept(ILayerVisitor& visitor) const
100{
101 visitor.VisitStackLayer(this, GetParameters(), GetName());
102}
Jan Eilers1b2654f2021-09-24 15:45:46 +0100103ARMNN_NO_DEPRECATE_WARN_END
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100104
105} // namespace armnn armnn