IVGCVSW-7608 IVGCVSW-7594 IVGCVSW-7598 IVGCVSW-7599 Implement Floor,
Lstm, Pooling2d and Pooling3d operators for the Opaque Delegate
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: Ic9af1c50589285ab359661699d32a889cd267cd9
diff --git a/delegate/opaque/src/armnn_delegate.cpp b/delegate/opaque/src/armnn_delegate.cpp
index 3b647f3..ead577f 100644
--- a/delegate/opaque/src/armnn_delegate.cpp
+++ b/delegate/opaque/src/armnn_delegate.cpp
@@ -641,6 +641,12 @@
tfLiteNode,
nodeIndex,
kTfLiteBuiltinArgMin);
+ case kTfLiteBuiltinAveragePool2d:
+ return VisitPooling2dOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinAveragePool2d);
case kTfLiteBuiltinBatchMatmul:
return VisitBatchMatMulOperator(delegateData,
tfLiteContext,
@@ -684,6 +690,30 @@
tfLiteNode,
nodeIndex,
kTfLiteBuiltinConv3d);
+ case kTfLiteBuiltinCustom:
+ {
+ // Custom operators are defined by the name rather than the builtin code.
+ // Parse the custom_name param in the registration to point to the correct visitor function.
+ std::string customOperatorName = TfLiteRegistrationExternalGetCustomName(tfLiteRegistration);
+ if ( customOperatorName == "AveragePool3D" )
+ {
+ return VisitPooling3dOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ customOperatorName);
+ }
+ else if (customOperatorName == "MaxPool3D")
+ {
+ return VisitPooling3dOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ customOperatorName);
+ }
+ // Invalid or unsupported custom operator
+ return kTfLiteError;
+ }
case kTfLiteBuiltinDepthwiseConv2d:
return VisitConvolutionOperator(delegateData,
tfLiteContext,
@@ -710,6 +740,12 @@
nodeIndex,
kTfLiteBuiltinExp,
armnn::UnaryOperation::Exp);
+ case kTfLiteBuiltinFloor:
+ return VisitFloorOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinFloor);
case kTfLiteBuiltinFullyConnected:
return VisitFullyConnectedOperator(delegateData,
tfLiteContext,
@@ -754,6 +790,12 @@
tfLiteNode,
nodeIndex,
kTfLiteBuiltinL2Normalization);
+ case kTfLiteBuiltinL2Pool2d:
+ return VisitPooling2dOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinL2Pool2d);
case kTfLiteBuiltinLess:
return VisitComparisonOperator(delegateData,
tfLiteContext,
@@ -808,6 +850,18 @@
nodeIndex,
kTfLiteBuiltinLogicalOr,
armnn::LogicalBinaryOperation::LogicalOr);
+ case kTfLiteBuiltinLstm:
+ return VisitLstmOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinLstm);
+ case kTfLiteBuiltinMaxPool2d:
+ return VisitPooling2dOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinMaxPool2d);
case kTfLiteBuiltinMean:
return VisitControlOperator(delegateData,
tfLiteContext,