//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NetworkQuantizer.hpp"
#include "NetworkQuantizerUtils.hpp"
#include "Graph.hpp"
#include "Layer.hpp"
#include "Network.hpp"
#include "DynamicQuantizationStrategy.hpp"
#include "StaticRangeStrategy.hpp"
#include "QuantizerStrategy.hpp"
#include "OverrideInputRangeVisitor.hpp"

#include <TensorIOUtils.hpp>

#include <armnn/ILayerVisitor.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <armnnUtils/TensorUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <mapbox/variant.hpp>

#include <vector>
#include <cmath>

namespace armnn
{

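// Variant container for output buffers of the element types supported by MakeOutputTensors;
// Refine() uses it below to hold the outputs of the calibration inference.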
using TContainer =
    mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>, std::vector<int8_t>>;

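// A minimal usage sketch (hypothetical caller code, not part of this translation unit):
// the caller builds an INetwork, creates a quantizer for it, optionally feeds calibration
// batches through Refine(), and finally exports the quantized network.
//
//     armnn::QuantizerOptions options;   // default options; see QuantizerOptions for fields
//     armnn::INetworkQuantizerPtr quantizer =
//         armnn::INetworkQuantizer::Create(network.get(), options);
//     quantizer->OverrideInputRange(0, -1.0f, 1.0f);       // optional, per input binding id
//     quantizer->Refine(calibrationInputTensors);          // optional, once per calibration batch
//     armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
//
// 'network' (an INetworkPtr) and 'calibrationInputTensors' (an InputTensors set) are assumed
// to be prepared by the caller.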
INetworkQuantizer* INetworkQuantizer::CreateRaw(INetwork* inputNetwork, const QuantizerOptions& options)
{
    return new NetworkQuantizer(inputNetwork, options);
}

INetworkQuantizerPtr INetworkQuantizer::Create(INetwork* inputNetwork, const QuantizerOptions& options)
{
    return INetworkQuantizerPtr(CreateRaw(inputNetwork, options), &INetworkQuantizer::Destroy);
}

void INetworkQuantizer::Destroy(INetworkQuantizer *quantizer)
{
    delete PolymorphicDowncast<NetworkQuantizer*>(quantizer);
}

void NetworkQuantizer::OverrideInputRange(LayerBindingId layerId, float min, float max)
{
    const Graph& graph = m_InputNetwork->pNetworkImpl->GetGraph();
    auto inputLayers = graph.GetInputLayers();

    // Walk the input layers of the graph and override the quantization parameters of the one with the given id
    OverrideInputRangeVisitor overrideInputRangeVisitor(m_Ranges, layerId, RangeTracker::MinMaxRange{min, max});
    VisitLayers(inputLayers, overrideInputRangeVisitor);
}

void NetworkQuantizer::Refine(const InputTensors& inputTensors)
{
    // The first time Refine is called, m_Runtime and the DynamicQuantizationStrategy
    // will not have been created. The environment has to be set up: the Runtime loaded,
    // the DynamicQuantizationStrategy created and run over the network to initialise
    // itself and the RangeTracker, the Debug callback registered, and an initial
    // inference done to set up the first min/max values.
    if (!m_Runtime)
    {
        m_RefineCount = 0;
        m_Ranges.SetDynamicMode(true);
        const Graph& cGraph = m_InputNetwork->pNetworkImpl->GetGraph().TopologicalSort();

        // Need to insert Debug layers in the DynamicQuantizationStrategy
        Graph& graph = const_cast<Graph&>(cGraph);

        // Initialize RangeTracker to the default values for each layer.
        // The default values are overwritten by the min/max that is
        // recorded during the first dataset min/max calibration. This
        // initialisation is only required for the first call of Refine().
        m_DynamicQuantizationStrategy = DynamicQuantizationStrategy(m_Ranges, graph);
        ApplyStrategyToLayers(cGraph, m_DynamicQuantizationStrategy.value());

        IRuntime::CreationOptions options;
        m_Runtime = IRuntime::Create(options);

        // Optimize network - debug already enabled for layers that require quantization
        OptimizerOptions optimizerOptions(false, false);
        std::vector<BackendId> backends = {"CpuRef"};
        IOptimizedNetworkPtr optimizedNet = Optimize(*m_InputNetwork,
                                                     backends,
                                                     m_Runtime->GetDeviceSpec(),
                                                     optimizerOptions);

        m_Runtime->LoadNetwork(m_NetworkId, std::move(optimizedNet));

        // Debug callback function to refine min/max in RangeTracker
        auto rangeTrackerCallback = [&](LayerGuid guid, unsigned int slotIndex, ITensorHandle *tensorHandle) {
            // Get min/max pair from tensor data
            std::pair<float, float> minMax = armnnUtils::FindMinMax(tensorHandle);

            // For the first calibration dataset, set the min/max range in the RangeTracker to
            // the min/max ranges gathered during inference
            if (m_RefineCount == 0)
            {
                m_Ranges.ResetMinMax(guid, slotIndex, minMax.first, minMax.second);
            }
            else
            {
                // For every other calibration dataset, only update the min/max range if the
                // values gathered are less than / greater than those originally recorded.
                m_Ranges.RefineMin(guid, slotIndex, minMax.first);
                m_Ranges.RefineMax(guid, slotIndex, minMax.second);
            }
        };

        m_Runtime->RegisterDebugCallback(m_NetworkId, rangeTrackerCallback);
    }

    // Create output tensors for EnqueueWorkload
    std::vector<armnn::BindingPointInfo> outputBindings;
    auto outputLayers = m_DynamicQuantizationStrategy.value().GetOutputLayers();
    std::vector<TContainer> outputVectors;
    for (auto outputLayerBindingId : outputLayers)
    {
        auto outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, outputLayerBindingId);
        outputBindings.push_back(std::make_pair(outputLayerBindingId, outputTensorInfo));
        outputVectors.push_back(std::vector<float>(outputTensorInfo.GetNumElements(), 0));
    }
    OutputTensors outputTensors = armnnUtils::MakeOutputTensors<TContainer>(outputBindings, outputVectors);

    // Execute EnqueueWorkload with the calibration image
    m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
    ++m_RefineCount;
}

INetworkPtr NetworkQuantizer::ExportNetwork()
{
    const Graph& graph = m_InputNetwork->pNetworkImpl->GetGraph().TopologicalSort();

    // Step 1) Walk the graph and populate default min/max values for
    // intermediate tensors, but only if the Runtime does not exist (it is
    // created if Refine has been called)
    if (!m_Runtime)
    {
        m_Ranges.SetDynamicMode(false);
        StaticRangeStrategy rangeStrategy(m_Ranges);
        ApplyStrategyToLayers(graph, rangeStrategy);
    }
    else
    {
        // Set the min/max range of non-calibrated layers to the parent layer's range
        m_DynamicQuantizationStrategy.value().VisitNonCalibratedLayers();
        // Now tear down the runtime and the dynamic strategy.
        m_Runtime.reset(nullptr);
        m_DynamicQuantizationStrategy = EmptyOptional();
        m_RefineCount = 0;
    }

    // Step 2) Convert the input network to a quantized network
    std::unique_ptr<IQuantizationScheme> quantizationScheme;
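    // Pick the quantization scheme that matches the requested activation data type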
    switch (m_Options.m_ActivationFormat)
    {
        case DataType::QAsymmU8:
            quantizationScheme = std::make_unique<QAsymmU8QuantizationScheme>();
            break;
        case DataType::QAsymmS8:
            quantizationScheme = std::make_unique<QAsymmS8QuantizationScheme>();
            break;
        case DataType::QSymmS8:
            quantizationScheme = std::make_unique<QSymmS8QuantizationScheme>();
            break;
        case DataType::QSymmS16:
            quantizationScheme = std::make_unique<QSymm16QuantizationScheme>();
            break;
        default:
            throw InvalidArgumentException("Unsupported quantization target");
    }

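    // Apply the QuantizerStrategy to every layer, building the quantized network from the
    // recorded ranges and the chosen quantization scheme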
    QuantizerStrategy quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType);
    ApplyStrategyToLayers(graph, quantizerVisitor);

    // clear the ranges
    m_Ranges.Reset();

    return quantizerVisitor.RetrieveFinalNetwork();
}

} // namespace armnn