blob: 13a11d3e612487e91b2e8a4b9e86762567df42fb [file] [log] [blame]
Sadik Armagana2747482021-02-09 10:28:54 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
12#include <tensorflow/lite/minimal_logging.h>
13
14namespace armnnDelegate
15{
16
17TfLiteStatus VisitReduceOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t reduceOperatorCode)
22{
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor, reduceOperatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteOutputTensor, reduceOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
40 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
41
42 // Get const axis value from model and set it to descriptor.
43 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
44 if (!IsValid(tfLiteContext, tfLiteAxisTensor, reduceOperatorCode, nodeIndex))
45 {
46 return kTfLiteError;
47 }
48
49 const armnn::TensorInfo& axisTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAxisTensor);
50 auto* axisTensorData = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
51
52 std::vector<int32_t> axis;
53 // Add axis data to vector to be converter to unsigned int and assigned to descriptor axis.
54 if (axisTensorData != nullptr)
55 {
56 for (unsigned int i = 0; i < axisTensorInfo.GetNumElements(); ++i)
57 {
58 axis.emplace_back(axisTensorData[i]);
59 }
60 }
61 else
62 {
63 for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
64 {
65 axis.push_back(i);
66 }
67 }
68
69 // Convert the axis to unsigned int and remove duplicates.
70 unsigned int rank = inputTensorInfo.GetNumDimensions();
71 std::set<unsigned int> uniqueAxis;
72 std::transform(axis.begin(),
73 axis.end(),
74 std::inserter(uniqueAxis, uniqueAxis.begin()),
75 [rank](int i)->unsigned int{ return (i + rank) % rank; });
76
77 armnn::ReduceDescriptor desc;
78 desc.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
79
80 auto* reducerParameters = reinterpret_cast<TfLiteReducerParams*>(tfLiteNode->builtin_data);
81 desc.m_KeepDims = reducerParameters->keep_dims;
82 if (reduceOperatorCode == kTfLiteBuiltinReduceMax)
83 {
84 desc.m_ReduceOperation = armnn::ReduceOperation::Max;
85 }
86 else if (reduceOperatorCode == kTfLiteBuiltinReduceMin)
87 {
88 desc.m_ReduceOperation = armnn::ReduceOperation::Min;
89 }
90 else if (reduceOperatorCode == kTfLiteBuiltinSum)
91 {
92 desc.m_ReduceOperation = armnn::ReduceOperation::Sum;
93 }
94 else
95 {
96 TF_LITE_MAYBE_KERNEL_LOG(
97 tfLiteContext,
98 "TfLiteArmnnDelegate: Unsupported Reduction Operator #%d node #%d: ",
99 reduceOperatorCode, nodeIndex);
100 return kTfLiteError;
101 }
102
103 bool isSupported = false;
104 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
105 {
106 FORWARD_LAYER_SUPPORT_FUNC(__func__,
107 tfLiteContext,
108 IsReduceSupported,
109 delegateData.m_Backends,
110 isSupported,
111 inputTensorInfo,
112 outInfo,
113 desc);
114 };
115
116 if (!delegateData.m_Network)
117 {
118 validateFunc(outputTensorInfo, isSupported);
119 return isSupported ? kTfLiteOk : kTfLiteError;
120 }
121
122 // Add an Reduce layer
123 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReduceLayer(desc);
124 ARMNN_ASSERT(layer != nullptr);
125
126 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
127 outputSlot.SetTensorInfo(outputTensorInfo);
128
129 // Connect
130 return Connect(layer, tfLiteNode, delegateData);
131}
132
133} // namespace armnnDelegate