//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
| 5 | |
| 6 | #pragma once |
| 7 | |
| 8 | #include <armnn/INetwork.hpp> |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 9 | #include <armnnQuantizer/INetworkQuantizer.hpp> |
| 10 | #include <armnn/IRuntime.hpp> |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 11 | #include <armnn/Types.hpp> |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 12 | #include <armnn/Optional.hpp> |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 13 | |
Finn Williams | b454c5c | 2021-02-09 15:56:23 +0000 | [diff] [blame] | 14 | #include "DynamicQuantizationStrategy.hpp" |
Derek Lamberti | 8a4ca10 | 2019-02-08 17:54:20 +0000 | [diff] [blame] | 15 | #include "RangeTracker.hpp" |
Matteo Martincigh | a8d572d | 2019-02-07 17:51:09 +0000 | [diff] [blame] | 16 | |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 17 | namespace armnn |
| 18 | { |
| 19 | |
| 20 | class NetworkQuantizer : public INetworkQuantizer |
| 21 | { |
| 22 | public: |
Nattapat Chaimanowong | 7ac07f3 | 2019-03-20 11:51:14 +0000 | [diff] [blame] | 23 | NetworkQuantizer(INetwork* inputNetwork, const QuantizerOptions& options) |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 24 | : m_InputNetwork(inputNetwork), |
| 25 | m_NetworkId(0), |
| 26 | m_Runtime(nullptr, &IRuntime::Destroy), |
| 27 | m_RefineCount(0), |
| 28 | m_Options(options) {} |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 29 | |
Matteo Martincigh | a8d572d | 2019-02-07 17:51:09 +0000 | [diff] [blame] | 30 | void OverrideInputRange(LayerBindingId layerId, float min, float max) override; |
Nina Drozd | 59e15b0 | 2019-04-25 15:45:20 +0100 | [diff] [blame] | 31 | void Refine(const InputTensors& inputTensors) override; |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 32 | |
| 33 | // Required for testing? Need some way to get min/max in RangeTracker (m_Ranges) |
| 34 | std::pair<float, float> GetMinMaxRange(LayerGuid guid, unsigned int idx) { return m_Ranges.GetRange(guid, idx); } |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 35 | INetworkPtr ExportNetwork() override; |
| 36 | |
| 37 | private: |
Derek Lamberti | 8a4ca10 | 2019-02-08 17:54:20 +0000 | [diff] [blame] | 38 | /// Original input network to quantize |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 39 | INetwork* m_InputNetwork; |
Matteo Martincigh | a8d572d | 2019-02-07 17:51:09 +0000 | [diff] [blame] | 40 | |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 41 | NetworkId m_NetworkId; |
| 42 | |
| 43 | // if we are run in dynamic mode this unique pointer will hold |
| 44 | // the runtime between invocations of the Refine method. |
| 45 | IRuntimePtr m_Runtime; |
| 46 | |
Finn Williams | b454c5c | 2021-02-09 15:56:23 +0000 | [diff] [blame] | 47 | Optional<DynamicQuantizationStrategy> m_DynamicQuantizationStrategy; |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 48 | |
| 49 | // counts the number of times refine is called |
| 50 | unsigned int m_RefineCount; |
| 51 | |
Matteo Martincigh | a8d572d | 2019-02-07 17:51:09 +0000 | [diff] [blame] | 52 | /// Mapping from Guid to an array of ranges for outputs |
Derek Lamberti | 8a4ca10 | 2019-02-08 17:54:20 +0000 | [diff] [blame] | 53 | RangeTracker m_Ranges; |
Nattapat Chaimanowong | 7ac07f3 | 2019-03-20 11:51:14 +0000 | [diff] [blame] | 54 | |
| 55 | /// Options for the NetworkQuantizer |
| 56 | QuantizerOptions m_Options; |
Jim Flynn | f92dfce | 2019-05-02 11:33:25 +0100 | [diff] [blame] | 57 | |
| 58 | std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle); |
Derek Lamberti | 27d8307 | 2019-02-05 16:00:08 +0000 | [diff] [blame] | 59 | }; |
| 60 | |
Matteo Martincigh | a8d572d | 2019-02-07 17:51:09 +0000 | [diff] [blame] | 61 | } //namespace armnn |