blob: a3758fd532348e67c3b59d7c60cf1d8e6df3a98c [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
#include "ArmnnDriver.hpp"
#include "ConversionUtils.hpp"

#include <armnn/ArmNN.hpp>

#include <cstdint>
#include <map>
#include <set>
14
15namespace armnn_driver
16{
17
// Overall outcome of converting a HAL model into an armnn::INetwork.
// The enumerator values are spelled out explicitly; they are identical to the implicit 0, 1, 2.
enum class ConversionResult
{
    Success            = 0, // Conversion finished without a recorded error.
    ErrorMappingPools  = 1, // NOTE(review): presumably the model's memory pools could not be mapped - confirm in the .cpp.
    UnsupportedFeature = 2  // NOTE(review): presumably the model uses something ArmNN cannot handle - confirm in the .cpp.
};
24
arovir01b0717b52018-09-05 17:03:25 +010025// A helper template class performing the conversion from an AndroidNN driver Model representation,
telsoa015307bc12018-03-09 13:51:08 +000026// to an armnn::INetwork object
arovir01b0717b52018-09-05 17:03:25 +010027template<typename HalPolicy>
telsoa015307bc12018-03-09 13:51:08 +000028class ModelToINetworkConverter
29{
30public:
arovir01b0717b52018-09-05 17:03:25 +010031 using HalModel = typename HalPolicy::Model;
kevmay01bc5f7842018-08-30 12:34:39 +010032
telsoa01ce3e84a2018-08-31 09:31:35 +010033 ModelToINetworkConverter(armnn::Compute compute,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010034 const HalModel& model,
35 const std::set<unsigned int>& forcedUnsupportedOperations);
telsoa015307bc12018-03-09 13:51:08 +000036
37 ConversionResult GetConversionResult() const { return m_ConversionResult; }
38
39 // Returns the ArmNN INetwork corresponding to the input model, if preparation went smoothly, nullptr otherwise.
arovir01b0717b52018-09-05 17:03:25 +010040 armnn::INetwork* GetINetwork() const { return m_Data.m_Network.get(); }
telsoa015307bc12018-03-09 13:51:08 +000041
42 bool IsOperationSupported(uint32_t operationIndex) const;
43
44private:
45 void Convert();
46
arovir01b0717b52018-09-05 17:03:25 +010047 // Shared aggregate input/output/internal data
48 ConversionData m_Data;
telsoa015307bc12018-03-09 13:51:08 +000049
telsoa015307bc12018-03-09 13:51:08 +000050 // Input data
kevmay01bc5f7842018-08-30 12:34:39 +010051 const HalModel& m_Model;
52 const std::set<unsigned int>& m_ForcedUnsupportedOperations;
telsoa015307bc12018-03-09 13:51:08 +000053
54 // Output data
telsoa01ce3e84a2018-08-31 09:31:35 +010055 ConversionResult m_ConversionResult;
56 std::map<uint32_t, bool> m_OperationSupported;
telsoa015307bc12018-03-09 13:51:08 +000057};
58
Matteo Martincigh79250ab2018-09-04 16:28:10 +010059} // armnn_driver