Nina Drozd | 59e15b0 | 2019-04-25 15:45:20 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #pragma once |
| 7 | |
#include "QuantizationInput.hpp"
#include "armnn/LayerVisitorBase.hpp"
#include "armnn/Tensor.hpp"

#include <map>
#include <string>
#include <vector>
| 12 | |
| 13 | namespace armnnQuantizer |
| 14 | { |
| 15 | |
| 16 | /// QuantizationDataSet is a structure which is created after parsing a quantization CSV file. |
| 17 | /// It contains records of filenames which contain refinement data per pass ID for binding ID. |
| 18 | class QuantizationDataSet |
| 19 | { |
| 20 | using QuantizationInputs = std::vector<armnnQuantizer::QuantizationInput>; |
| 21 | public: |
| 22 | |
| 23 | using iterator = QuantizationInputs::iterator; |
| 24 | using const_iterator = QuantizationInputs::const_iterator; |
| 25 | |
| 26 | QuantizationDataSet(); |
| 27 | QuantizationDataSet(std::string csvFilePath); |
| 28 | ~QuantizationDataSet(); |
| 29 | bool IsEmpty() const {return m_QuantizationInputs.empty();} |
| 30 | |
| 31 | iterator begin() { return m_QuantizationInputs.begin(); } |
| 32 | iterator end() { return m_QuantizationInputs.end(); } |
| 33 | const_iterator begin() const { return m_QuantizationInputs.begin(); } |
| 34 | const_iterator end() const { return m_QuantizationInputs.end(); } |
| 35 | const_iterator cbegin() const { return m_QuantizationInputs.cbegin(); } |
| 36 | const_iterator cend() const { return m_QuantizationInputs.cend(); } |
| 37 | |
| 38 | private: |
| 39 | void ParseCsvFile(); |
| 40 | |
| 41 | QuantizationInputs m_QuantizationInputs; |
| 42 | std::string m_CsvFilePath; |
| 43 | }; |
| 44 | |
| 45 | /// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine. |
| 46 | class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> |
| 47 | { |
| 48 | public: |
| 49 | void VisitInputLayer(const armnn::IConnectableLayer *layer, armnn::LayerBindingId id, const char* name); |
| 50 | armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId); |
| 51 | private: |
| 52 | std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos; |
| 53 | }; |
| 54 | |
| 55 | } // namespace armnnQuantizer |