blob: 38396c1a9c27cc106b618530514becff9fba0742 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "Optimization.hpp"
namespace armnn
{
namespace optimizations
{
class DeleteBroadcastToImpl
{
public:
/// Run for every BroadcastToLayer. Remove it if it is before an ElementWiseLayer.
/// Since ElementWiseBinary uses a brodcastLoop, using a broadcastTo layer is
/// not necessary so it will be deleted.
void Run(Graph&, BroadcastToLayer& layer) const
{
if(layer.GetType() == LayerType::BroadcastTo)
{
TensorInfo info = layer.GetOutputSlot(0).GetTensorInfo();
Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer();
if (next.GetType() == LayerType::ElementwiseBinary)
{
Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
auto tensorInfo = connectedLayer.GetOutputSlot().GetTensorInfo();
layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot());
connectedLayer.GetOutputSlot().GetOutputHandler().SetTensorInfo(tensorInfo);
}
}
}
protected:
DeleteBroadcastToImpl() = default;
~DeleteBroadcastToImpl() = default;
};
using BroadcastToOptimizationLayer = OptimizeForType<BroadcastToLayer, DeleteBroadcastToImpl>;
}
}