blob: f7829590204a94172300d0b0a55d5b60e3ce5ea2 [file] [log] [blame]
Finn Williamsb454c5c2021-02-09 15:56:23 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#pragma once

#include "Network.hpp"
#include "NetworkQuantizerUtils.hpp"
#include "StaticRangeStrategy.hpp"

#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <unordered_map>
#include <utility>
#include <vector>
12
13namespace armnn
14{
15class QuantizerStrategy : public IStrategy
16{
17public :
18 QuantizerStrategy(const RangeTracker& rangeTracker,
19 const IQuantizationScheme* quantizationScheme,
20 bool preserveType);
21
22 ~QuantizerStrategy() = default;
23
24 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
25 const BaseDescriptor& descriptor,
26 const std::vector<armnn::ConstTensor>& constants,
27 const char* name,
28 const armnn::LayerBindingId id) override;
29
30 /// Extract the quantized network
31 INetworkPtr RetrieveFinalNetwork() { return std::move(m_QuantizedNetwork); }
32
33private:
34 /// Connects the layer to preceeding layers and sets the quantization parameters based on recorded ranges
35 void SetQuantizedInputConnections(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer);
36
37 /// Record the guids so we can easily find the layers later
38 void RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* qLayer);
39
40 /// Sets the bias quantization scale based on input and weight scales
41 ConstTensor CreateQuantizedBias(const IConnectableLayer* srcLayer,
42 const ConstTensor& weights,
43 const Optional<ConstTensor>& biases,
44 std::vector<int32_t>& weightsBacking);
45
46 /// Reference to the static range visitor used to retrieve the quantization ranges
47 const RangeTracker& m_Ranges;
48
49 /// Quantized version of the model we are building up
50 INetworkPtr m_QuantizedNetwork;
51
52 /// Mapping from input network guids to quantized network guids
53 std::unordered_map<LayerGuid, LayerGuid> m_OriginalToQuantizedGuidMap;
54
55 /// Mapping from guid to layer in quantized network
56 std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap;
57
58 const IQuantizationScheme* m_QuantizationScheme;
59
60 const bool m_PreserveType;
61};
62
63} //namespace armnn