blob: e78c5f02688275de8de0d8c69f2a230c1329e38e [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
#include "ArmnnDriver.hpp"
#include "ConversionUtils.hpp"

#include <armnn/ArmNN.hpp>

#include <cstdint>
#include <map>
#include <set>
#include <vector>
telsoa015307bc12018-03-09 13:51:08 +000015
16namespace armnn_driver
17{
18
// Overall outcome of converting a HAL model, as exposed through
// ModelToINetworkConverter::GetConversionResult().
// NOTE(review): the meanings below are inferred from the enumerator names; the code that
// assigns these values lives outside this header.
enum class ConversionResult : int
{
    Success            = 0, // conversion completed
    ErrorMappingPools  = 1, // presumably: the model's memory pools could not be mapped
    UnsupportedFeature = 2  // presumably: the model uses something the converter cannot handle
};
25
arovir01b0717b52018-09-05 17:03:25 +010026// A helper template class performing the conversion from an AndroidNN driver Model representation,
telsoa015307bc12018-03-09 13:51:08 +000027// to an armnn::INetwork object
arovir01b0717b52018-09-05 17:03:25 +010028template<typename HalPolicy>
telsoa015307bc12018-03-09 13:51:08 +000029class ModelToINetworkConverter
30{
31public:
arovir01b0717b52018-09-05 17:03:25 +010032 using HalModel = typename HalPolicy::Model;
kevmay01bc5f7842018-08-30 12:34:39 +010033
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +010034 ModelToINetworkConverter(const std::vector<armnn::BackendId>& backends,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010035 const HalModel& model,
36 const std::set<unsigned int>& forcedUnsupportedOperations);
telsoa015307bc12018-03-09 13:51:08 +000037
38 ConversionResult GetConversionResult() const { return m_ConversionResult; }
39
40 // Returns the ArmNN INetwork corresponding to the input model, if preparation went smoothly, nullptr otherwise.
arovir01b0717b52018-09-05 17:03:25 +010041 armnn::INetwork* GetINetwork() const { return m_Data.m_Network.get(); }
telsoa015307bc12018-03-09 13:51:08 +000042
43 bool IsOperationSupported(uint32_t operationIndex) const;
44
45private:
46 void Convert();
47
arovir01b0717b52018-09-05 17:03:25 +010048 // Shared aggregate input/output/internal data
49 ConversionData m_Data;
telsoa015307bc12018-03-09 13:51:08 +000050
telsoa015307bc12018-03-09 13:51:08 +000051 // Input data
kevmay01bc5f7842018-08-30 12:34:39 +010052 const HalModel& m_Model;
53 const std::set<unsigned int>& m_ForcedUnsupportedOperations;
telsoa015307bc12018-03-09 13:51:08 +000054
55 // Output data
telsoa01ce3e84a2018-08-31 09:31:35 +010056 ConversionResult m_ConversionResult;
57 std::map<uint32_t, bool> m_OperationSupported;
telsoa015307bc12018-03-09 13:51:08 +000058};
59
Matteo Martincigh79250ab2018-09-04 16:28:10 +010060} // armnn_driver