blob: 1bfe84bc199ca27072256f805a2f0012ed945427 [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <map>
9#include <armnn/Types.hpp>
Jim Flynnf92dfce2019-05-02 11:33:25 +010010#include <armnnQuantizer/INetworkQuantizer.hpp>
Nina Drozd59e15b02019-04-25 15:45:20 +010011
12namespace armnnQuantizer
13{
14
15/// QuantizationInput for specific pass ID, can list a corresponding raw data file for each LayerBindingId.
16class QuantizationInput
17{
18public:
19
20 /// Constructor for QuantizationInput
21 QuantizationInput(const unsigned int passId,
22 const armnn::LayerBindingId bindingId,
23 const std::string fileName);
24
25 QuantizationInput(const QuantizationInput& other);
26
27 // Add binding ID to image tensor filepath entry
28 void AddEntry(const armnn::LayerBindingId bindingId, const std::string fileName);
29
30 // Retrieve tensor data for entry with provided binding ID
31 std::vector<float> GetDataForEntry(const armnn::LayerBindingId bindingId) const;
32
33 /// Retrieve Layer Binding IDs for this QuantizationInput.
34 std::vector<armnn::LayerBindingId> GetLayerBindingIds() const;
35
36 /// Get number of inputs for this QuantizationInput.
37 unsigned long GetNumberOfInputs() const;
38
39 /// Retrieve Pass ID for this QuantizationInput.
40 unsigned int GetPassId() const;
41
42 /// Retrieve filename path for specified Layer Binding ID.
43 std::string GetFileName(const armnn::LayerBindingId bindingId) const;
44
45 /// Destructor
46 ~QuantizationInput() noexcept;
47
48private:
49 unsigned int m_PassId;
50 std::map<armnn::LayerBindingId, std::string> m_LayerBindingIdToFileName;
51
52};
53
54}