blob: b8db04ccf2ae541c9e10a25f0a4b95d6f074059a [file] [log] [blame]
Sadik Armagana2747482021-02-09 10:28:54 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
12#include <tensorflow/lite/minimal_logging.h>
13
14namespace armnnDelegate
15{
16
17TfLiteStatus VisitReduceOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t reduceOperatorCode)
22{
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor, reduceOperatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteOutputTensor, reduceOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
40 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
41
42 // Get const axis value from model and set it to descriptor.
43 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
44 if (!IsValid(tfLiteContext, tfLiteAxisTensor, reduceOperatorCode, nodeIndex))
45 {
46 return kTfLiteError;
47 }
48
49 const armnn::TensorInfo& axisTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAxisTensor);
50 auto* axisTensorData = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
51
52 std::vector<int32_t> axis;
53 // Add axis data to vector to be converter to unsigned int and assigned to descriptor axis.
54 if (axisTensorData != nullptr)
55 {
56 for (unsigned int i = 0; i < axisTensorInfo.GetNumElements(); ++i)
57 {
58 axis.emplace_back(axisTensorData[i]);
59 }
60 }
61 else
62 {
63 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
64 {
65 axis.push_back(i);
66 }
67 }
68
69 // Convert the axis to unsigned int and remove duplicates.
70 unsigned int rank = inputTensorInfo.GetNumDimensions();
71 std::set<unsigned int> uniqueAxis;
72 std::transform(axis.begin(),
73 axis.end(),
74 std::inserter(uniqueAxis, uniqueAxis.begin()),
75 [rank](int i)->unsigned int{ return (i + rank) % rank; });
76
77 armnn::ReduceDescriptor desc;
78 desc.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
79
80 auto* reducerParameters = reinterpret_cast<TfLiteReducerParams*>(tfLiteNode->builtin_data);
81 desc.m_KeepDims = reducerParameters->keep_dims;
82 if (reduceOperatorCode == kTfLiteBuiltinReduceMax)
83 {
84 desc.m_ReduceOperation = armnn::ReduceOperation::Max;
85 }
86 else if (reduceOperatorCode == kTfLiteBuiltinReduceMin)
87 {
88 desc.m_ReduceOperation = armnn::ReduceOperation::Min;
89 }
90 else if (reduceOperatorCode == kTfLiteBuiltinSum)
91 {
92 desc.m_ReduceOperation = armnn::ReduceOperation::Sum;
93 }
Teresa Charlin4e3e8312021-08-05 12:34:37 +010094 else if (reduceOperatorCode == kTfLiteBuiltinReduceProd)
95 {
96 desc.m_ReduceOperation = armnn::ReduceOperation::Prod;
97 }
Sadik Armagana2747482021-02-09 10:28:54 +000098 else
99 {
100 TF_LITE_MAYBE_KERNEL_LOG(
101 tfLiteContext,
102 "TfLiteArmnnDelegate: Unsupported Reduction Operator #%d node #%d: ",
103 reduceOperatorCode, nodeIndex);
104 return kTfLiteError;
105 }
106
107 bool isSupported = false;
108 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
109 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000110 FORWARD_LAYER_SUPPORT_FUNC("REDUCE",
Sadik Armagana2747482021-02-09 10:28:54 +0000111 tfLiteContext,
112 IsReduceSupported,
113 delegateData.m_Backends,
114 isSupported,
115 inputTensorInfo,
116 outInfo,
117 desc);
118 };
119
120 if (!delegateData.m_Network)
121 {
122 validateFunc(outputTensorInfo, isSupported);
123 return isSupported ? kTfLiteOk : kTfLiteError;
124 }
125
126 // Add an Reduce layer
127 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReduceLayer(desc);
128 ARMNN_ASSERT(layer != nullptr);
129
130 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
131 outputSlot.SetTensorInfo(outputTensorInfo);
132
133 // Connect
134 return Connect(layer, tfLiteNode, delegateData);
135}
136
137} // namespace armnnDelegate