IVGCVSW-1770: Refactor ModelToINetworkConverter to allow conversion of HAL1.1 operators
Change-Id: I9b10f0a9c88344df108b2325c0233f9660fa6b7c
diff --git a/ModelToINetworkConverter.hpp b/ModelToINetworkConverter.hpp
index f0e2897..6fdcf6b 100644
--- a/ModelToINetworkConverter.hpp
+++ b/ModelToINetworkConverter.hpp
@@ -33,13 +33,28 @@
UnsupportedFeature
};
+struct HalVersion_1_0 // HAL 1.0 selector: passed as the HalVersion template argument of ModelToINetworkConverter
+{
+ using Model = ::android::hardware::neuralnetworks::V1_0::Model; // read by the converter as HalVersion::Model
+};
+
+#if defined(ARMNN_ANDROID_NN_V1_1) // HalVersion_1_1 is only declared when this macro is defined
+struct HalVersion_1_1 // HAL 1.1 selector: same shape as HalVersion_1_0, but aliasing the V1_1 model
+{
+ using Model = ::android::hardware::neuralnetworks::V1_1::Model; // read by the converter as HalVersion::Model
+};
+#endif // defined(ARMNN_ANDROID_NN_V1_1)
+
// A helper performing the conversion from an AndroidNN driver Model representation,
// to an armnn::INetwork object
+template <typename HalVersion>
class ModelToINetworkConverter
{
public:
+ using HalModel = typename HalVersion::Model;
+
ModelToINetworkConverter(armnn::Compute compute,
- const ::android::hardware::neuralnetworks::V1_0::Model& model,
+ const HalModel& model,
const std::set<unsigned int>& forcedUnsupportedOperations);
ConversionResult GetConversionResult() const { return m_ConversionResult; }
@@ -52,6 +67,10 @@
private:
void Convert();
+#if defined(ARMNN_ANDROID_NN_V1_1)
+ bool ConvertOperation(const ::android::hardware::neuralnetworks::V1_1::Operation& operation);
+#endif
+
bool ConvertOperation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
bool ConvertAdd(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
@@ -112,7 +131,7 @@
const Operand* GetOutputOperand(const ::android::hardware::neuralnetworks::V1_0::Operation& operation,
uint32_t outputIndex) const;
- template<typename T>
+ template <typename T>
bool GetInputScalar(const ::android::hardware::neuralnetworks::V1_0::Operation& operation, uint32_t inputIndex,
OperandType type, T& outValue) const;
@@ -172,9 +191,9 @@
// Input data
- armnn::Compute m_Compute;
- const ::android::hardware::neuralnetworks::V1_0::Model& m_Model;
- const std::set<unsigned int>& m_ForcedUnsupportedOperations;
+ armnn::Compute m_Compute;
+ const HalModel& m_Model;
+ const std::set<unsigned int>& m_ForcedUnsupportedOperations;
// Output data
armnn::INetworkPtr m_Network;