blob: a0b270fba55eedbdd5d8eb6798bb139c7864bd75 [file] [log] [blame]
//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ScatterNdLayer.hpp"
7#include "LayerCloneBase.hpp"
8
9#include <armnn/TypesUtils.hpp>
10#include <armnn/backends/WorkloadData.hpp>
11#include <armnn/backends/WorkloadFactory.hpp>
12
13namespace armnn
14{
15
16ScatterNdLayer::ScatterNdLayer(const ScatterNdDescriptor &param, const char* name)
17 : LayerWithParameters(3, 1, LayerType::ScatterNd, param, name)
18{
19}
20
21std::unique_ptr<IWorkload> ScatterNdLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const
22{
23 ScatterNdQueueDescriptor descriptor;
24 SetAdditionalInfo(descriptor);
25
26 return factory.CreateWorkload(LayerType::ScatterNd, descriptor, PrepInfoAndDesc(descriptor));
27}
28
29ScatterNdLayer* ScatterNdLayer::Clone(Graph& graph) const
30{
31 auto layer = CloneBase<ScatterNdLayer>(graph, m_Param, GetName());
32
33 return std::move(layer);
34}
35
36std::vector<TensorShape> ScatterNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
37{
38 const auto inputDims = inputShapes[0].GetNumDimensions();
39
40 std::vector<unsigned int> dimSizes(inputDims);
41
42 for (unsigned i = 0; i < inputDims; ++i)
43 {
44 dimSizes[i] = inputShapes[0][i];
45 }
46
47 TensorShape outputShape({ inputDims, dimSizes.data() });
48
49 return std::vector<TensorShape>({ outputShape });
50}
51
52void ScatterNdLayer::ValidateTensorShapesFromInputs()
53{
54 VerifyLayerConnections(3, CHECK_LOCATION());
55
56 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
57
58 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
59
60 if (m_Param.m_InputEnabled)
61 {
62 std::vector<TensorShape> inferredShapes = InferOutputShapes(
63 {GetInputSlot(0).GetTensorInfo().GetShape(),
64 GetInputSlot(1).GetTensorInfo().GetShape(),
65 GetInputSlot(2).GetTensorInfo().GetShape()});
66
67 if (inferredShapes.size() != 1) {
68 throw armnn::LayerValidationException("inferredShape has " +
69 std::to_string(inferredShapes.size()) +
70 " elements - should only have 1.");
71 }
72
73 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "ScatterLayer");
74 }
75 else
76 {
77 // No input tensor, only shape provided via input slot
78 // In this case, we cannot validate the output shape from the input shape, but we can
79 // validate that the dimensions of shape and output tensor matched
80 unsigned int shapeDims = GetInputSlot(0).GetTensorInfo().GetNumDimensions();
81 unsigned int outputDims = GetOutputSlot(0).GetTensorInfo().GetNumDimensions();
82
83 if (shapeDims != outputDims)
84 {
85 throw armnn::LayerValidationException("shape dimension " +
86 std::to_string(shapeDims) +
87 " and output dimension " +
88 std::to_string(outputDims) +
89 " are not matched.");
90 }
91 }
92}
93
94} // namespace armnn