//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Graph.hpp"

namespace armnn
{

13inline std::vector<ConvertFp16ToFp32Layer*> InsertConvertFp16ToFp32LayersBefore(Graph& graph, Layer& layer)
14{
15 std::vector<ConvertFp16ToFp32Layer*> convertLayers;
16 convertLayers.reserve(layer.GetNumInputSlots());
17
18 for (auto&& inputSlot = layer.BeginInputSlots(); inputSlot != layer.EndInputSlots(); ++inputSlot)
19 {
20 // Insert FP16 to FP32 converter layer before the layer
21 const std::string name =
22 std::string("convert_fp16_to_fp32-" + std::to_string(inputSlot->GetSlotIndex()) + "-") + layer.GetName();
23 ConvertFp16ToFp32Layer* convertLayer =
24 graph.InsertNewLayer<ConvertFp16ToFp32Layer>(*inputSlot, name.c_str());
25
26 // Sets output tensor info for the convert layer
27 TensorInfo convertInfo = convertLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
28 convertInfo.SetDataType(DataType::Float32);
29
30 convertLayer->GetOutputSlot().SetTensorInfo(convertInfo);
31
32 convertLayers.emplace_back(convertLayer);
33 }
34
35 // Sets the output tensor info for the unsupported layer
36 auto UpdateTensorInfo = [](auto& outputSlot)
37 {
38 // Copy original tensor info and change data type to FP32
39 TensorInfo newTensorInfo = outputSlot.GetTensorInfo();
40 newTensorInfo.SetDataType(DataType::Float32);
41
42 outputSlot.SetTensorInfo(newTensorInfo);
43 };
44
45 std::for_each(layer.BeginOutputSlots(), layer.EndOutputSlots(), UpdateTensorInfo);
46
47 return convertLayers;
48}
49
50inline std::vector<ConvertFp32ToFp16Layer*> InsertConvertFp32ToFp16LayersAfter(Graph& graph, Layer& layer)
51{
52 std::vector<ConvertFp32ToFp16Layer*> convertLayers;
53 convertLayers.reserve(layer.GetNumOutputSlots());
54
55 int index = 0;
56 // Change outputs to DataType::Float16
57 for (auto&& outputSlot = layer.BeginOutputSlots(); outputSlot != layer.EndOutputSlots(); ++outputSlot)
58 {
59 BOOST_ASSERT(outputSlot->GetTensorInfo().GetDataType() == DataType::Float32);
60
61 // Insert FP32 to FP16 converter layer after the layer
62 const std::string name =
63 std::string("convert_fp32_to_fp16-" + std::to_string(index++) + "-") + layer.GetName();
64 ConvertFp32ToFp16Layer* convertLayer =
65 graph.InsertNewLayer<ConvertFp32ToFp16Layer>(*outputSlot, name.c_str());
66
67 // Sets output tensor info for the convert layer.
68 TensorInfo convertInfo = convertLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
69 convertInfo.SetDataType(DataType::Float16);
70
71 convertLayer->GetOutputSlot().SetTensorInfo(convertInfo);
72
73 convertLayers.emplace_back(convertLayer);
74 }
75
76 return convertLayers;
77}
78
79} //namespace armnn