//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"
#include "NetworkUtils.hpp"

namespace armnn
{
namespace optimizations
{

class ConvertConstDequantisationLayersToConstLayersImpl
{
public:
    void Run(Graph& graph, InputSlot& connection) const
    {
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        ARMNN_ASSERT(base.GetType() == LayerType::Constant);
        ARMNN_ASSERT(child.GetType() == LayerType::Dequantize);

        ReplaceConstDequantisationLayer(graph,
                                        PolymorphicDowncast<ConstantLayer*>(&base),
                                        PolymorphicDowncast<DequantizeLayer*>(&child));
    }

protected:
    ConvertConstDequantisationLayersToConstLayersImpl() = default;
    ~ConvertConstDequantisationLayersToConstLayersImpl() = default;

private:
    static void ReplaceConstDequantisationLayer(Graph& graph,
                                                ConstantLayer* constantLayer,
                                                DequantizeLayer* dequantizeLayer)
    {
        IgnoreUnused(graph);
        /**
         * This optimisation finds situations where a constant set of inputs feeds a Dequantize
         * layer. In that case we do not want the overhead of dequantizing the values on every
         * inference; instead we dequantize them once here and store the results in a Constant
         * layer, which can be reused on every inference since the values never change.
         */
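        // For example (illustrative pattern only - the consumer could be any layer type):
        //     ConstantLayer (QAsymmS8) -> DequantizeLayer -> Convolution2dLayer
        // is rewritten as
        //     ConstantLayer (Float32) -> Convolution2dLayer
        // with the Float32 values computed once, at optimisation time.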
        TensorInfo constantInfo = constantLayer->GetOutputSlot(0).GetTensorInfo();
        TensorInfo inputDequantizeInfo = dequantizeLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
        TensorInfo outputDequantizeInfo = dequantizeLayer->GetOutputSlot(0).GetTensorInfo();

        ARMNN_ASSERT(constantLayer->GetNumOutputSlots() == 1);
        auto numConnections = constantLayer->GetOutputSlot(0).GetNumConnections();

        std::vector<float> newValues(outputDequantizeInfo.GetNumElements());
        if (constantInfo.GetDataType() == DataType::Float16 &&
            inputDequantizeInfo.GetDataType() == DataType::Float16 &&
            outputDequantizeInfo.GetDataType() == DataType::Float32)
        {
            armnnUtils::FloatingPointConverter::ConvertFloat16To32(constantLayer->m_LayerOutput->Map(true),
                                                                   outputDequantizeInfo.GetNumElements(),
                                                                   newValues.data());
        }
        else if (constantInfo.GetDataType() == DataType::QAsymmS8 &&
                 inputDequantizeInfo.GetDataType() == DataType::QAsymmS8 &&
                 outputDequantizeInfo.GetDataType() == DataType::Float32)
        {
            // Dequantize using the quantization scale and offset stored on the constant tensor,
            // so the stored Float32 values match what the Dequantize layer would have produced.
            ConvertInt8To32(constantLayer->m_LayerOutput->Map(true),
                            outputDequantizeInfo.GetNumElements(),
                            constantInfo.GetQuantizationScale(),
                            constantInfo.GetQuantizationOffset(),
                            newValues.data());
        }
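        // Worked example (illustrative values): a QAsymmS8 element of 20 with scale 0.1f and
        // offset 5 dequantizes to (20 - 5) * 0.1f = 1.5f, which is the value the removed
        // DequantizeLayer would have produced at runtime.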

        TensorInfo newInfo = outputDequantizeInfo;
        newInfo.SetConstant(true);
        ConstTensor newInput(newInfo, newValues);
        constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));

        // Move the connections on the dequantize layer's output over to the constant layer.
        // The dequantize layer will be removed later if it is left unconnected.
        dequantizeLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());

        // Update the output tensor info on the constant layer.
        constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
        ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);

        // Set isConstant to true in all input tensor infos that constantLayer is now connected to.
        for (unsigned int i = numConnections; i < constantLayer->GetOutputSlot(0).GetNumConnections(); ++i)
        {
            auto info = constantLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer().GetInputSlot(0)
                                       .GetConnectedOutputSlot()->GetTensorInfo();
            info.SetConstant();
            constantLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer().GetInputSlot(0)
                                       .GetConnectedOutputSlot()->SetTensorInfo(info);
        }
    }


    static void ConvertInt8To32(const void* srcInt8Buffer,
                                size_t numElements,
                                const float scale,
                                const int32_t offset,
                                float* dstFloat32Buffer)
    {
        ARMNN_ASSERT(srcInt8Buffer != nullptr);
        ARMNN_ASSERT(dstFloat32Buffer != nullptr);

        const auto* pInt8 = static_cast<const int8_t*>(srcInt8Buffer);

        for (size_t i = 0; i < numElements; ++i)
        {
            // Standard affine dequantization: realValue = (quantizedValue - offset) * scale.
            dstFloat32Buffer[i] = static_cast<float>(pInt8[i] - offset) * scale;
        }
    }

};

using ConvertConstDequantisationLayersToConstLayers
    = OptimizeForConnection<ConstantLayer,
                            DequantizeLayer,
                            ConvertConstDequantisationLayersToConstLayersImpl>;
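
// Usage sketch (illustrative, not part of this header): the optimisation is applied to a
// Graph through the generic optimizer entry points, e.g.
//
//     Optimizer::Pass(graph, MakeOptimizations(ConvertConstDequantisationLayersToConstLayers()));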

} // namespace optimizations
} // namespace armnn