blob: 196a3aab1db6d32912048a3adba6c3938e4c676d [file] [log] [blame]
Matteo Martincigha8d572d2019-02-07 17:51:09 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "NetworkQuantizer.hpp"
Mike Kelly8c1701a2019-02-11 17:01:27 +00009#include "armnn/LayerVisitorBase.hpp"
Derek Lamberti8a4ca102019-02-08 17:54:20 +000010#include "RangeTracker.hpp"
Matteo Martincigha8d572d2019-02-07 17:51:09 +000011
12#include <unordered_map>
13
14namespace armnn
15{
Finn Williamsb454c5c2021-02-09 15:56:23 +000016class OverrideInputRangeStrategy : public IStrategy
17{
18private:
19 using MinMaxRange = RangeTracker::MinMaxRange;
20public :
21 OverrideInputRangeStrategy(RangeTracker& ranges,
22 LayerBindingId layerId,
23 const MinMaxRange& minMaxRange)
24 : m_Ranges(ranges)
25 , m_LayerId(layerId)
26 , m_MinMaxRange(minMaxRange){}
27
28 ~OverrideInputRangeStrategy() = default;
29
30 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
31 const BaseDescriptor& descriptor,
32 const std::vector<armnn::ConstTensor>& constants,
33 const char* name,
34 const armnn::LayerBindingId id) override
35 {
36 IgnoreUnused(name, constants, id, descriptor);
37
38 switch (layer->GetType())
39 {
40 case armnn::LayerType::Input :
41 {
42 if (m_LayerId == id)
43 {
44 m_Ranges.SetRange(layer, 0, m_MinMaxRange.first, m_MinMaxRange.second);
45 }
46 break;
47 }
48 default:
49 {
50 std::cout << "dont know this one" << std::endl;
51 }
52 }
53 }
54
55private:
56 /// Mapping from a layer Guid to an array of ranges for outputs
57 RangeTracker& m_Ranges;
58
59 /// The id of the input layer of which to override the input range
60 LayerBindingId m_LayerId;
61
62 /// The new input range to be applied to the input layer
63 MinMaxRange m_MinMaxRange;
64};
65
66
Matteo Martincigha8d572d2019-02-07 17:51:09 +000067
68/// Visitor object for overriding the input range of the quantized input layers in a network
69class OverrideInputRangeVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
70{
71private:
Matteo Martincigh5b2159e2019-02-11 13:24:38 +000072 using MinMaxRange = RangeTracker::MinMaxRange;
Matteo Martincigha8d572d2019-02-07 17:51:09 +000073
74public:
Derek Lamberti8a4ca102019-02-08 17:54:20 +000075 OverrideInputRangeVisitor(RangeTracker& ranges,
Matteo Martincigha8d572d2019-02-07 17:51:09 +000076 LayerBindingId layerId,
77 const MinMaxRange& minMaxRange);
78 ~OverrideInputRangeVisitor() = default;
79
Matteo Martincigh9c5d33a2019-02-07 17:52:41 +000080 void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override;
Matteo Martincigha8d572d2019-02-07 17:51:09 +000081
82private:
Matteo Martincigha8d572d2019-02-07 17:51:09 +000083 /// Mapping from a layer Guid to an array of ranges for outputs
Derek Lamberti8a4ca102019-02-08 17:54:20 +000084 RangeTracker& m_Ranges;
Matteo Martincigha8d572d2019-02-07 17:51:09 +000085
86 /// The id of the input layer of which to override the input range
87 LayerBindingId m_LayerId;
88
89 /// The new input range to be applied to the input layer
90 MinMaxRange m_MinMaxRange;
91};
92
93} // namespace armnn