blob: a07ac8827ecd1c96d5cdae7f64abcb69ea15cb02 [file] [log] [blame]
Derek Lamberti27d83072019-02-05 16:00:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/INetwork.hpp>
Jim Flynnf92dfce2019-05-02 11:33:25 +01009#include <armnnQuantizer/INetworkQuantizer.hpp>
10#include <armnn/IRuntime.hpp>
Derek Lamberti27d83072019-02-05 16:00:08 +000011#include <armnn/Types.hpp>
Jim Flynnf92dfce2019-05-02 11:33:25 +010012#include <armnn/Optional.hpp>
Derek Lamberti27d83072019-02-05 16:00:08 +000013
Finn Williamsb454c5c2021-02-09 15:56:23 +000014#include "DynamicQuantizationStrategy.hpp"
Derek Lamberti8a4ca102019-02-08 17:54:20 +000015#include "RangeTracker.hpp"
Matteo Martincigha8d572d2019-02-07 17:51:09 +000016
Derek Lamberti27d83072019-02-05 16:00:08 +000017namespace armnn
18{
19
20class NetworkQuantizer : public INetworkQuantizer
21{
22public:
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000023 NetworkQuantizer(INetwork* inputNetwork, const QuantizerOptions& options)
Jim Flynnf92dfce2019-05-02 11:33:25 +010024 : m_InputNetwork(inputNetwork),
25 m_NetworkId(0),
26 m_Runtime(nullptr, &IRuntime::Destroy),
27 m_RefineCount(0),
28 m_Options(options) {}
Derek Lamberti27d83072019-02-05 16:00:08 +000029
Matteo Martincigha8d572d2019-02-07 17:51:09 +000030 void OverrideInputRange(LayerBindingId layerId, float min, float max) override;
Nina Drozd59e15b02019-04-25 15:45:20 +010031 void Refine(const InputTensors& inputTensors) override;
Jim Flynnf92dfce2019-05-02 11:33:25 +010032
33 // Required for testing? Need some way to get min/max in RangeTracker (m_Ranges)
34 std::pair<float, float> GetMinMaxRange(LayerGuid guid, unsigned int idx) { return m_Ranges.GetRange(guid, idx); }
Derek Lamberti27d83072019-02-05 16:00:08 +000035 INetworkPtr ExportNetwork() override;
36
37private:
Derek Lamberti8a4ca102019-02-08 17:54:20 +000038 /// Original input network to quantize
Derek Lamberti27d83072019-02-05 16:00:08 +000039 INetwork* m_InputNetwork;
Matteo Martincigha8d572d2019-02-07 17:51:09 +000040
Jim Flynnf92dfce2019-05-02 11:33:25 +010041 NetworkId m_NetworkId;
42
43 // if we are run in dynamic mode this unique pointer will hold
44 // the runtime between invocations of the Refine method.
45 IRuntimePtr m_Runtime;
46
Finn Williamsb454c5c2021-02-09 15:56:23 +000047 Optional<DynamicQuantizationStrategy> m_DynamicQuantizationStrategy;
Jim Flynnf92dfce2019-05-02 11:33:25 +010048
49 // counts the number of times refine is called
50 unsigned int m_RefineCount;
51
Matteo Martincigha8d572d2019-02-07 17:51:09 +000052 /// Mapping from Guid to an array of ranges for outputs
Derek Lamberti8a4ca102019-02-08 17:54:20 +000053 RangeTracker m_Ranges;
Nattapat Chaimanowong7ac07f32019-03-20 11:51:14 +000054
55 /// Options for the NetworkQuantizer
56 QuantizerOptions m_Options;
Jim Flynnf92dfce2019-05-02 11:33:25 +010057
58 std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle);
Derek Lamberti27d83072019-02-05 16:00:08 +000059};
60
Matteo Martincigha8d572d2019-02-07 17:51:09 +000061} //namespace armnn