blob: 447a4da9ab6383b4b7c6146415bc46f0f0e66b24 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/utility/IgnoreUnused.hpp>
9
10#include "DelegateUtils.hpp"
11
12#include <tensorflow/lite/builtin_ops.h>
13#include <tensorflow/lite/c/builtin_op_data.h>
14#include <tensorflow/lite/c/common.h>
15#include <tensorflow/lite/minimal_logging.h>
16#include <numeric>
17
18namespace armnnDelegate
19{
20
21TfLiteStatus VisitUnpackOperator(DelegateData& delegateData,
22 TfLiteContext* tfLiteContext,
23 TfLiteNode* tfLiteNode,
24 int nodeIndex,
25 int32_t operatorCode)
26{
27 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28
29 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
30 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
31
32 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33 {
34 return kTfLiteError;
35 }
36
37 // Get Unpack Axis
38 const auto params = reinterpret_cast<TfLiteUnpackParams*>(tfLiteNode->builtin_data);
39
40 const unsigned int unpackAxis = NonNegative(params->axis, nodeIndex);
41
42 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
43
44 if (unpackAxis >= inputTensorInfo.GetNumDimensions())
45 {
46 TF_LITE_MAYBE_KERNEL_LOG(
47 tfLiteContext,
48 "TfLiteArmnnDelegate: The unpack axis #%d cannot be greater than or equal to "
49 "the number of input dimensions #%d in operator #%d node #%d",
50 unpackAxis, inputTensorInfo.GetNumDimensions(), operatorCode, nodeIndex);
51 return kTfLiteError;
52 }
53
54 // Get Unpack Num
55 unsigned int unpackNum = NonNegative(params->num, nodeIndex);
56
57 // If num is not defined, automatically infer from the length of the dimension axis.
58 if(unpackNum == 0)
59 {
60 unpackNum = inputTensorInfo.GetShape()[unpackAxis];
61 }
62
63 // If unpack number cannot be inferred and is still zero, return kTfLiteError.
64 if(unpackNum == 0)
65 {
66 TF_LITE_MAYBE_KERNEL_LOG(
67 tfLiteContext,
68 "TfLiteArmnnDelegate: Number to unpack must greater than zero in operator #%d node #%d: ",
69 operatorCode, nodeIndex);
70 return kTfLiteError;
71 }
72
73 // Check outputs
74 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, unpackNum, nodeIndex));
75
76
77 auto inputDimSize = inputTensorInfo.GetNumDimensions();
78 std::vector<unsigned int> unpackDimSizes(inputDimSize);
79
80 // Add current input shape to unpackDimSizes
81 for (unsigned int i = 0; i < inputDimSize; ++i)
82 {
83 unpackDimSizes[i] = inputTensorInfo.GetShape()[i];
84 }
85
86 if (unpackDimSizes[unpackAxis] != unpackNum)
87 {
88 TF_LITE_MAYBE_KERNEL_LOG(
89 tfLiteContext,
90 "TfLiteArmnnDelegate: Number to unpack must be the same as length "
91 "of the dimension to unpack along in operator #%d node #%d: ",
92 operatorCode, nodeIndex);
93 return kTfLiteError;
94 }
95
96 unpackDimSizes[unpackAxis] /= unpackNum;
97
98 armnn::SplitterDescriptor splitDesc(unpackNum, static_cast<unsigned int>(unpackDimSizes.size()));
99 for (unsigned int j = 0; j < unpackNum; ++j)
100 {
101 // Set the size of the views.
102 for (unsigned int dimIdx = 0; dimIdx < unpackDimSizes.size(); ++dimIdx)
103 {
104 splitDesc.SetViewSize(j, dimIdx, unpackDimSizes[dimIdx]);
105 }
106 splitDesc.SetViewOriginCoord(j, unpackAxis, unpackDimSizes[unpackAxis] * j);
107 }
108
109 std::vector<armnn::TensorInfo> outputs;
110 for (unsigned int i = 0; i < unpackNum; ++i)
111 {
112 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
113 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
114 {
115 return kTfLiteError;
116 }
117 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor));
118 }
119 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
120
Kevin May4cad8602021-05-18 09:57:43 +0100121 // Determine the shape of the Splitter layer outputs for validation
122 armnn::TensorShape splitOutShape = armnn::TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
123 unpackDimSizes.data());
124
125 std::vector<armnn::TensorInfo> splitterOutputs;
126 for (unsigned int outputIndex = 0; outputIndex < outputTensorInfos.size(); ++outputIndex)
127 {
128 splitterOutputs.push_back(armnn::TensorInfo(splitOutShape,
129 outputTensorInfos[outputIndex].get().GetDataType(),
130 outputTensorInfos[outputIndex].get().GetQuantizationScale(),
131 outputTensorInfos[outputIndex].get().GetQuantizationOffset()));
132 }
133 std::vector<std::reference_wrapper<armnn::TensorInfo>> splitterOutputTensorInfos(splitterOutputs.begin(),
134 splitterOutputs.end());
135
Kevin May8ab2d7a2021-05-07 09:32:51 +0100136 if (!delegateData.m_Network)
137 {
Kevin May4cad8602021-05-18 09:57:43 +0100138 // Check if splitter is supported
Kevin May8ab2d7a2021-05-07 09:32:51 +0100139 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000140 FORWARD_LAYER_SUPPORT_FUNC("UNPACK",
Kevin May8ab2d7a2021-05-07 09:32:51 +0100141 tfLiteContext,
142 IsSplitterSupported,
143 delegateData.m_Backends,
144 isSupported,
145 inputTensorInfo,
Kevin May4cad8602021-05-18 09:57:43 +0100146 splitterOutputTensorInfos,
Kevin May8ab2d7a2021-05-07 09:32:51 +0100147 splitDesc);
148 return isSupported ? kTfLiteOk : kTfLiteError;
149 }
150
Kevin May4cad8602021-05-18 09:57:43 +0100151 // Create Reshape descriptor from the first outputTensorInfo to validate a single Reshape layer
152 // Use this descriptor later when creating every ReshapeLayer as all Reshape Layers should be the same
153 armnn::ReshapeDescriptor reshapeDescriptor;
154 reshapeDescriptor.m_TargetShape = outputTensorInfos[0].get().GetShape();
155
156 if (!delegateData.m_Network)
157 {
158 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000159 FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
Kevin May4cad8602021-05-18 09:57:43 +0100160 tfLiteContext,
161 IsReshapeSupported,
162 delegateData.m_Backends,
163 isSupported,
164 splitterOutputTensorInfos[0],
165 outputTensorInfos[0],
166 reshapeDescriptor);
167 return isSupported ? kTfLiteOk : kTfLiteError;
168 };
169
Kevin May8ab2d7a2021-05-07 09:32:51 +0100170 std::string splitterLayerName("Unpack Splitter");
171
172 armnn::IConnectableLayer* splitterLayer = delegateData.m_Network->AddSplitterLayer(splitDesc,
173 splitterLayerName.c_str());
174 ARMNN_ASSERT(splitterLayer != nullptr);
175
176 for (unsigned int k = 0; k < splitterLayer->GetNumOutputSlots(); ++k)
177 {
178 splitterLayer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
179 }
180
181 // Connect the input slots
182 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(splitterLayer->GetInputSlot(0));
183
Kevin May8ab2d7a2021-05-07 09:32:51 +0100184 // Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter.
185 for (unsigned int outputIndex = 0; outputIndex < splitterLayer->GetNumOutputSlots(); ++outputIndex)
186 {
Kevin May8ab2d7a2021-05-07 09:32:51 +0100187 std::string reshapeLayerName("Unpack Reshape");
Kevin May8ab2d7a2021-05-07 09:32:51 +0100188 armnn::IConnectableLayer* reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor,
189 reshapeLayerName.c_str());
Kevin May8ab2d7a2021-05-07 09:32:51 +0100190 ARMNN_ASSERT(reshapeLayer != nullptr);
191
Kevin May4cad8602021-05-18 09:57:43 +0100192 splitterLayer->GetOutputSlot(outputIndex).SetTensorInfo(splitterOutputTensorInfos[outputIndex]);
Kevin May8ab2d7a2021-05-07 09:32:51 +0100193 splitterLayer->GetOutputSlot(outputIndex).Connect(reshapeLayer->GetInputSlot(0));
194
Kevin May4cad8602021-05-18 09:57:43 +0100195 armnn::TensorInfo outputTensorInfo = outputTensorInfos[outputIndex];
Kevin May8ab2d7a2021-05-07 09:32:51 +0100196 reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
197
198 armnn::IOutputSlot& slot = reshapeLayer->GetOutputSlot(0);
199
200 delegateData.m_OutputSlotForNode[
201 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &slot;
202
203 }
204
205 return kTfLiteOk;
206}
207
208} // namespace armnnDelegate