blob: 4095ac4ac2e505e75abab28e465214e9767454ba [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
Sadik Armagan32ca1442020-11-13 17:51:56 +00008#include "DelegateUtils.hpp"
9
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus VisitPoolingOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
Narumol Prangnawarat50c87d32020-11-09 18:42:11 +000022 int32_t tfLitePoolingOperatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010023{
Narumol Prangnawarat50c87d32020-11-09 18:42:11 +000024 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26
27 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (IsDynamicTensor(tfLiteInputTensor))
30 {
31 TF_LITE_MAYBE_KERNEL_LOG(
32 tfLiteContext,
33 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
34 tfLitePoolingOperatorCode, nodeIndex);
35 return kTfLiteError;
36 }
37
38 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
39 if (IsDynamicTensor(tfLiteOutputTensor))
40 {
41 TF_LITE_MAYBE_KERNEL_LOG(
42 tfLiteContext,
43 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
44 tfLitePoolingOperatorCode, nodeIndex);
45 return kTfLiteError;
46 }
47
48 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
49 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
50
51 armnn::PoolingAlgorithm poolingAlgorithm;
52 switch(tfLitePoolingOperatorCode)
53 {
Narumol Prangnawarat80815362020-11-11 11:33:03 +000054 case kTfLiteBuiltinAveragePool2d:
55 poolingAlgorithm = armnn::PoolingAlgorithm::Average;
56 break;
57 case kTfLiteBuiltinL2Pool2d:
58 poolingAlgorithm = armnn::PoolingAlgorithm::L2;
59 break;
Narumol Prangnawarat50c87d32020-11-09 18:42:11 +000060 case kTfLiteBuiltinMaxPool2d:
61 poolingAlgorithm = armnn::PoolingAlgorithm::Max;
62 break;
63 default:
64 return kTfLiteError;
65 }
66
67 armnn::Pooling2dDescriptor descriptor;
68 descriptor.m_PoolType = poolingAlgorithm;
69
70 auto* params = reinterpret_cast<TfLitePoolParams*>(tfLiteNode->builtin_data);
71 descriptor.m_PoolWidth = params->filter_width;
72 descriptor.m_PoolHeight = params->filter_height;
73 descriptor.m_StrideX = params->stride_width;
74 descriptor.m_StrideY = params->stride_height;
75 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
76
77 unsigned int inputHeight = inputTensorInfo.GetShape()[1];
78 unsigned int inputWidth = inputTensorInfo.GetShape()[2];
79
80 CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
81 descriptor.m_PadTop, descriptor.m_PadBottom, params->padding);
82 CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
83 descriptor.m_PadLeft, descriptor.m_PadRight, params->padding);
84
85 bool isSupported = false;
86 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
87 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000088 FORWARD_LAYER_SUPPORT_FUNC("POOLING_2D",
Narumol Prangnawarat50c87d32020-11-09 18:42:11 +000089 tfLiteContext,
90 IsPooling2dSupported,
91 delegateData.m_Backends,
92 isSupported,
93 inputTensorInfo,
94 outputTensorInfo,
95 descriptor);
96 };
97
98 if (!delegateData.m_Network)
99 {
100 validateFunc(outputTensorInfo, isSupported);
101 return isSupported ? kTfLiteOk : kTfLiteError;
102 }
103
104 armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling2dLayer(descriptor);
105 ARMNN_ASSERT(poolingLayer != nullptr);
106
107 armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
108 outputSlot.SetTensorInfo(outputTensorInfo);
109 Connect(poolingLayer, tfLiteNode, delegateData);
110
111 // Check activation
112 TfLiteFusedActivation activationType = params->activation;
113 return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100114}
115
116} // namespace armnnDelegate