| // |
| // Copyright © 2017 Arm Ltd. All rights reserved. |
| // See LICENSE file in the project root for full license information. |
| // |
| #pragma once |
| |
| #include "Optimization.hpp" |
| #include "NetworkUtils.hpp" |
| |
| namespace armnn |
| { |
| namespace optimizations |
| { |
| |
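// Converts a Float32 network to Float16 by:
// - inserting a ConvertFp32ToFp16 layer after each output of a Float32 InputLayer,
// - inserting a ConvertFp16ToFp32 layer before each input of a Float32 OutputLayer,
// - switching the TensorInfo of every other Float32 connection to Float16.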
| class ConvertFp32NetworkToFp16Impl |
| { |
| public: |
| |
| void Run(Graph& graph, Layer& layer) const |
| { |
        if (layer.GetType() == LayerType::Input)
| { |
| // if the outputs of this layer are DataType::Float32 |
| // add a ConvertFloat32ToFloat16 layer after each of the outputs |
| if (layer.GetDataType() == DataType::Float32) |
| { |
| InsertConvertFp32ToFp16LayersAfter(graph, layer); |
| } |
| } |
| else if (layer.GetType() == LayerType::Output) |
| { |
| // if the inputs of this layer are DataType::Float32 |
| // add a ConvertFloat16ToFloat32 layer before each of the inputs |
| if (layer.GetDataType() == DataType::Float32) |
| { |
| InsertConvertFp16ToFp32LayersBefore(graph, layer); |
| } |
| } |
| else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32) |
| { |
| // if the inputs/outputs of this layer are DataType::Float32 |
| // change the data type for all inputs and outputs to DataType::Float16 |
| for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input) |
| { |
                // if this slot is connected to an OutputSlot of the InputLayer, do not change the DataType of
                // the connection; the InputSlots of the current layer are updated when the conversion layer is
                // inserted after the InputLayer
| Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer(); |
| if (base.GetType() != LayerType::Input) |
| { |
| TensorInfo convertInfo = input->GetConnection()->GetTensorInfo(); |
| if (convertInfo.GetDataType() == DataType::Float32) |
| { |
| convertInfo.SetDataType(DataType::Float16); |
| input->GetConnection()->SetTensorInfo(convertInfo); |
| } |
| } |
| } |
| |
| // change outputs to DataType::Float16 |
| for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output) |
| { |
| TensorInfo convertInfo = output->GetTensorInfo(); |
| if (convertInfo.GetDataType() == DataType::Float32) |
| { |
| convertInfo.SetDataType(DataType::Float16); |
| output->SetTensorInfo(convertInfo); |
| } |
| } |
| } |
| } |
| |
| protected: |
| ConvertFp32NetworkToFp16Impl() = default; |
| ~ConvertFp32NetworkToFp16Impl() = default; |
| }; |
| |
| using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>; |
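
// Usage sketch (assumes the Optimizer and MakeOptimizations helpers declared in Optimizer.hpp;
// the exact call site may differ between ArmNN versions):
//
//     Optimizer::Pass(graph, MakeOptimizations(Fp32NetworkToFp16Converter()));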
| |
| } // namespace optimizations |
| } // namespace armnn |