blob: 2d8b462cd255731955888fdbdd071c210092ef5a [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
#include <tensorflow/lite/minimal_logging.h>

#include <set>
#include <vector>

14namespace armnnDelegate
15{
16
17TfLiteStatus VisitReduceOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t reduceOperatorCode)
22{
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor, reduceOperatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteOutputTensor, reduceOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010040 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagana2747482021-02-09 10:28:54 +000041
42 // Get const axis value from model and set it to descriptor.
43 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
44 if (!IsValid(tfLiteContext, tfLiteAxisTensor, reduceOperatorCode, nodeIndex))
45 {
46 return kTfLiteError;
47 }
48
49 const armnn::TensorInfo& axisTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAxisTensor);
50 auto* axisTensorData = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
51
52 std::vector<int32_t> axis;
53 // Add axis data to vector to be converter to unsigned int and assigned to descriptor axis.
54 if (axisTensorData != nullptr)
55 {
56 for (unsigned int i = 0; i < axisTensorInfo.GetNumElements(); ++i)
57 {
58 axis.emplace_back(axisTensorData[i]);
59 }
60 }
61 else
62 {
63 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
64 {
65 axis.push_back(i);
66 }
67 }
68
69 // Convert the axis to unsigned int and remove duplicates.
70 unsigned int rank = inputTensorInfo.GetNumDimensions();
71 std::set<unsigned int> uniqueAxis;
72 std::transform(axis.begin(),
73 axis.end(),
74 std::inserter(uniqueAxis, uniqueAxis.begin()),
75 [rank](int i)->unsigned int{ return (i + rank) % rank; });
76
77 armnn::ReduceDescriptor desc;
78 desc.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
79
80 auto* reducerParameters = reinterpret_cast<TfLiteReducerParams*>(tfLiteNode->builtin_data);
81 desc.m_KeepDims = reducerParameters->keep_dims;
82 if (reduceOperatorCode == kTfLiteBuiltinReduceMax)
83 {
84 desc.m_ReduceOperation = armnn::ReduceOperation::Max;
85 }
86 else if (reduceOperatorCode == kTfLiteBuiltinReduceMin)
87 {
88 desc.m_ReduceOperation = armnn::ReduceOperation::Min;
89 }
90 else if (reduceOperatorCode == kTfLiteBuiltinSum)
91 {
92 desc.m_ReduceOperation = armnn::ReduceOperation::Sum;
93 }
Teresa Charlin4e3e8312021-08-05 12:34:37 +010094 else if (reduceOperatorCode == kTfLiteBuiltinReduceProd)
95 {
96 desc.m_ReduceOperation = armnn::ReduceOperation::Prod;
97 }
Sadik Armagana2747482021-02-09 10:28:54 +000098 else
99 {
100 TF_LITE_MAYBE_KERNEL_LOG(
101 tfLiteContext,
102 "TfLiteArmnnDelegate: Unsupported Reduction Operator #%d node #%d: ",
103 reduceOperatorCode, nodeIndex);
104 return kTfLiteError;
105 }
106
107 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100108 armnn::BackendId setBackend;
Sadik Armagana2747482021-02-09 10:28:54 +0000109 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
110 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000111 FORWARD_LAYER_SUPPORT_FUNC("REDUCE",
Sadik Armagana2747482021-02-09 10:28:54 +0000112 tfLiteContext,
113 IsReduceSupported,
114 delegateData.m_Backends,
115 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100116 setBackend,
Sadik Armagana2747482021-02-09 10:28:54 +0000117 inputTensorInfo,
118 outInfo,
119 desc);
120 };
121
122 if (!delegateData.m_Network)
123 {
124 validateFunc(outputTensorInfo, isSupported);
125 return isSupported ? kTfLiteOk : kTfLiteError;
126 }
127
128 // Add an Reduce layer
129 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReduceLayer(desc);
Cathal Corbett53837672022-09-01 11:34:37 +0100130 layer->SetBackendId(setBackend);
Sadik Armagana2747482021-02-09 10:28:54 +0000131 ARMNN_ASSERT(layer != nullptr);
132
133 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
134 outputSlot.SetTensorInfo(outputTensorInfo);
135
Ryan OShea4c231de2023-01-17 15:19:20 +0000136 // try to connect the Constant Inputs if there are any
137 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
138 {
139 return kTfLiteError;
140 }
141
Sadik Armagana2747482021-02-09 10:28:54 +0000142 // Connect
143 return Connect(layer, tfLiteNode, delegateData);
144}
145
146} // namespace armnnDelegate