blob: aa77a4b56300187b0fda4d60ac808f965997bbdc [file] [log] [blame]
Finn Williamsb454c5c2021-02-09 15:56:23 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "armnn/LayerVisitorBase.hpp"
9#include "RangeTracker.hpp"
10#include "layers/DebugLayer.hpp"
11
12#include <armnn/INetwork.hpp>
13#include <armnnQuantizer/INetworkQuantizer.hpp>
14
15namespace armnn
16{
17
18/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
19class DynamicQuantizationStrategy : public armnn::IStrategy
20{
21public:
22
23 DynamicQuantizationStrategy(RangeTracker& rangeTracker, Graph& graph);
24 ~DynamicQuantizationStrategy() = default;
25
26 virtual void ExecuteStrategy(const armnn::IConnectableLayer* layer,
27 const armnn::BaseDescriptor& descriptor,
28 const std::vector<armnn::ConstTensor>& constants,
29 const char* name,
30 const armnn::LayerBindingId id = 0) override;
31
32 const std::vector<armnn::LayerBindingId>& GetOutputLayers();
33 void VisitNonCalibratedLayers();
34 void FinishStrategy() override;
35
36
37private:
38 /// Set the range for an output slot on a layer
39 void SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max);
40
41 void ForwardParentParameters(const IConnectableLayer* layer);
42
43 /// Mapping from a layer Guid to an array of ranges for outputs
44 RangeTracker& m_RangeTracker;
45
46 Graph& m_Graph;
47
48 std::vector<const IConnectableLayer*> m_LayersToCalibrate;
49 std::vector<const IConnectableLayer*> m_LayersNotToCalibrate;
50 std::vector<DebugLayer*> m_DebugLayers;
51
52 std::vector<armnn::LayerBindingId> m_OutputLayers;
53 void AddToCalibratedLayers(const IConnectableLayer* layer);
54 void AddToNonCalibratedLayers(const IConnectableLayer* layer);
55 void RemoveDebugLayers();
56
57
58};
59} //namespace armnn