blob: 3a97630ccfe097f778bd0090ef2787ece7615946 [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <map>
#include <string>
#include <vector>

#include "QuantizationInput.hpp"
#include "armnn/LayerVisitorBase.hpp"
#include "armnn/Tensor.hpp"
12
13namespace armnnQuantizer
14{
15
16/// QuantizationDataSet is a structure which is created after parsing a quantization CSV file.
17/// It contains records of filenames which contain refinement data per pass ID for binding ID.
18class QuantizationDataSet
19{
20 using QuantizationInputs = std::vector<armnnQuantizer::QuantizationInput>;
21public:
22
23 using iterator = QuantizationInputs::iterator;
24 using const_iterator = QuantizationInputs::const_iterator;
25
26 QuantizationDataSet();
27 QuantizationDataSet(std::string csvFilePath);
28 ~QuantizationDataSet();
29 bool IsEmpty() const {return m_QuantizationInputs.empty();}
30
31 iterator begin() { return m_QuantizationInputs.begin(); }
32 iterator end() { return m_QuantizationInputs.end(); }
33 const_iterator begin() const { return m_QuantizationInputs.begin(); }
34 const_iterator end() const { return m_QuantizationInputs.end(); }
35 const_iterator cbegin() const { return m_QuantizationInputs.cbegin(); }
36 const_iterator cend() const { return m_QuantizationInputs.cend(); }
37
38private:
39 void ParseCsvFile();
40
41 QuantizationInputs m_QuantizationInputs;
42 std::string m_CsvFilePath;
43};
44
45/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
46class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
47{
48public:
49 void VisitInputLayer(const armnn::IConnectableLayer *layer, armnn::LayerBindingId id, const char* name);
50 armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId);
51private:
52 std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos;
53};
54
55} // namespace armnnQuantizer