blob: aee95d063c466b30bc52293df8561d94591a4530 [file] [log] [blame]
surmeh013537c2c2018-05-18 16:31:43 +01001//
Finn Williams87d0bda2020-07-03 10:12:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh013537c2c2018-05-18 16:31:43 +01004//
5#include "ConstantLayer.hpp"
6#include "LayerCloneBase.hpp"
7
8#include <armnn/TypesUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +00009#include <armnn/backends/TensorHandle.hpp>
10#include <armnn/backends/WorkloadData.hpp>
11#include <armnn/backends/WorkloadFactory.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010012
13namespace armnn
14{
15
telsoa01c577f2c2018-08-31 09:22:23 +010016ConstantLayer::ConstantLayer(const char* name)
surmeh013537c2c2018-05-18 16:31:43 +010017 : Layer(0, 1, LayerType::Constant, name)
surmeh013537c2c2018-05-18 16:31:43 +010018{
19}
20
Derek Lamberti94a88d22019-12-10 21:12:59 +000021std::unique_ptr<IWorkload> ConstantLayer::CreateWorkload(const IWorkloadFactory& factory) const
surmeh013537c2c2018-05-18 16:31:43 +010022{
23 ConstantQueueDescriptor descriptor;
24 descriptor.m_LayerOutput = m_LayerOutput.get();
Keith Davisdf04d232020-10-23 17:20:05 +010025 SetAdditionalInfo(descriptor);
26
Teresa Charlin611c7fb2022-01-07 09:47:29 +000027 return factory.CreateWorkload(LayerType::Constant, descriptor, PrepInfoAndDesc(descriptor));
surmeh013537c2c2018-05-18 16:31:43 +010028}
29
30ConstantLayer* ConstantLayer::Clone(Graph& graph) const
31{
telsoa01c577f2c2018-08-31 09:22:23 +010032 // Cloned layers share the same layer output object.
33 auto layer = CloneBase<ConstantLayer>(graph, GetName());
34
Finn Williams4422cec2021-03-22 17:51:06 +000035 layer->m_LayerOutput = m_LayerOutput ? m_LayerOutput : nullptr;
telsoa01c577f2c2018-08-31 09:22:23 +010036
37 return std::move(layer);
38}
39
40std::vector<TensorShape> ConstantLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
41{
Derek Lamberti94a88d22019-12-10 21:12:59 +000042 return std::vector<TensorShape>({ inputShapes[0] });
surmeh013537c2c2018-05-18 16:31:43 +010043}
44
Finn Williamsf24effa2020-07-03 10:12:03 +010045void ConstantLayer::ValidateTensorShapesFromInputs()
surmeh013537c2c2018-05-18 16:31:43 +010046{
Teresa Charlincdc01492020-06-09 18:00:20 +010047
telsoa01c577f2c2018-08-31 09:22:23 +010048 // Get the output shape from the value of the constant layer.
surmeh013537c2c2018-05-18 16:31:43 +010049 TensorShape const& outShape = m_LayerOutput->GetTensorInfo().GetShape();
Finn Williams87d0bda2020-07-03 10:12:03 +010050
51 ConditionalThrow<LayerValidationException>(
52 outShape.GetDimensionality() != Dimensionality::NotSpecified,
53 "Constant layer m_LayerOutput output shape can not be Dimensionality::NotSpecified");
54
55 ConditionalThrow<LayerValidationException>(
56 outShape.AreAllDimensionsSpecified(),
57 "Constant layer m_LayerOutput output shape can not have an unspecified dimension");
58
surmeh013537c2c2018-05-18 16:31:43 +010059 ConditionalThrowIfNotEqual<LayerValidationException>(
Finn Williams87d0bda2020-07-03 10:12:03 +010060 "ConstantLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
61 GetOutputSlot(0).GetTensorInfo().GetShape(),
62 outShape);
surmeh013537c2c2018-05-18 16:31:43 +010063}
64
Finn Williamsb454c5c2021-02-09 15:56:23 +000065void ConstantLayer::ExecuteStrategy(IStrategy& strategy) const
66{
Francis Murtagh4af56162021-04-20 16:37:55 +010067 ManagedConstTensorHandle managedLayerOutput(m_LayerOutput);
68 ConstTensor layerOutputTensor(managedLayerOutput.GetTensorInfo(), managedLayerOutput.Map());
69 strategy.ExecuteStrategy(this, BaseDescriptor(), { layerOutputTensor }, GetName());
Finn Williamsb454c5c2021-02-09 15:56:23 +000070}
71
surmeh013537c2c2018-05-18 16:31:43 +010072} // namespace armnn