Tianle Cheng | 2828818 | 2024-02-23 17:56:54 +0000 | [diff] [blame^] | 1 | // |
| 2 | // Copyright © 2024 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "ScatterNdLayer.hpp" |
| 7 | #include "LayerCloneBase.hpp" |
| 8 | |
| 9 | #include <armnn/TypesUtils.hpp> |
| 10 | #include <armnn/backends/WorkloadData.hpp> |
| 11 | #include <armnn/backends/WorkloadFactory.hpp> |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | ScatterNdLayer::ScatterNdLayer(const ScatterNdDescriptor ¶m, const char* name) |
| 17 | : LayerWithParameters(3, 1, LayerType::ScatterNd, param, name) |
| 18 | { |
| 19 | } |
| 20 | |
| 21 | std::unique_ptr<IWorkload> ScatterNdLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const |
| 22 | { |
| 23 | ScatterNdQueueDescriptor descriptor; |
| 24 | SetAdditionalInfo(descriptor); |
| 25 | |
| 26 | return factory.CreateWorkload(LayerType::ScatterNd, descriptor, PrepInfoAndDesc(descriptor)); |
| 27 | } |
| 28 | |
| 29 | ScatterNdLayer* ScatterNdLayer::Clone(Graph& graph) const |
| 30 | { |
| 31 | auto layer = CloneBase<ScatterNdLayer>(graph, m_Param, GetName()); |
| 32 | |
| 33 | return std::move(layer); |
| 34 | } |
| 35 | |
| 36 | std::vector<TensorShape> ScatterNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const |
| 37 | { |
| 38 | const auto inputDims = inputShapes[0].GetNumDimensions(); |
| 39 | |
| 40 | std::vector<unsigned int> dimSizes(inputDims); |
| 41 | |
| 42 | for (unsigned i = 0; i < inputDims; ++i) |
| 43 | { |
| 44 | dimSizes[i] = inputShapes[0][i]; |
| 45 | } |
| 46 | |
| 47 | TensorShape outputShape({ inputDims, dimSizes.data() }); |
| 48 | |
| 49 | return std::vector<TensorShape>({ outputShape }); |
| 50 | } |
| 51 | |
| 52 | void ScatterNdLayer::ValidateTensorShapesFromInputs() |
| 53 | { |
| 54 | VerifyLayerConnections(3, CHECK_LOCATION()); |
| 55 | |
| 56 | const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); |
| 57 | |
| 58 | VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); |
| 59 | |
| 60 | if (m_Param.m_InputEnabled) |
| 61 | { |
| 62 | std::vector<TensorShape> inferredShapes = InferOutputShapes( |
| 63 | {GetInputSlot(0).GetTensorInfo().GetShape(), |
| 64 | GetInputSlot(1).GetTensorInfo().GetShape(), |
| 65 | GetInputSlot(2).GetTensorInfo().GetShape()}); |
| 66 | |
| 67 | if (inferredShapes.size() != 1) { |
| 68 | throw armnn::LayerValidationException("inferredShape has " + |
| 69 | std::to_string(inferredShapes.size()) + |
| 70 | " elements - should only have 1."); |
| 71 | } |
| 72 | |
| 73 | ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "ScatterLayer"); |
| 74 | } |
| 75 | else |
| 76 | { |
| 77 | // No input tensor, only shape provided via input slot |
| 78 | // In this case, we cannot validate the output shape from the input shape, but we can |
| 79 | // validate that the dimensions of shape and output tensor matched |
| 80 | unsigned int shapeDims = GetInputSlot(0).GetTensorInfo().GetNumDimensions(); |
| 81 | unsigned int outputDims = GetOutputSlot(0).GetTensorInfo().GetNumDimensions(); |
| 82 | |
| 83 | if (shapeDims != outputDims) |
| 84 | { |
| 85 | throw armnn::LayerValidationException("shape dimension " + |
| 86 | std::to_string(shapeDims) + |
| 87 | " and output dimension " + |
| 88 | std::to_string(outputDims) + |
| 89 | " are not matched."); |
| 90 | } |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | } // namespace armnn |