//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/utility/IgnoreUnused.hpp>

#include <ClassicDelegateUtils.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <numeric>
18namespace armnnDelegate
19{
20
21TfLiteStatus VisitUnpackOperator(DelegateData& delegateData,
22 TfLiteContext* tfLiteContext,
23 TfLiteNode* tfLiteNode,
24 int nodeIndex,
25 int32_t operatorCode)
26{
27 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28
29 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
30 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
31
32 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33 {
34 return kTfLiteError;
35 }
36
37 // Get Unpack Axis
38 const auto params = reinterpret_cast<TfLiteUnpackParams*>(tfLiteNode->builtin_data);
39
40 const unsigned int unpackAxis = NonNegative(params->axis, nodeIndex);
41
42 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
43
44 if (unpackAxis >= inputTensorInfo.GetNumDimensions())
45 {
46 TF_LITE_MAYBE_KERNEL_LOG(
47 tfLiteContext,
48 "TfLiteArmnnDelegate: The unpack axis #%d cannot be greater than or equal to "
49 "the number of input dimensions #%d in operator #%d node #%d",
50 unpackAxis, inputTensorInfo.GetNumDimensions(), operatorCode, nodeIndex);
51 return kTfLiteError;
52 }
53
54 // Get Unpack Num
55 unsigned int unpackNum = NonNegative(params->num, nodeIndex);
56
57 // If num is not defined, automatically infer from the length of the dimension axis.
58 if(unpackNum == 0)
59 {
60 unpackNum = inputTensorInfo.GetShape()[unpackAxis];
61 }
62
63 // If unpack number cannot be inferred and is still zero, return kTfLiteError.
64 if(unpackNum == 0)
65 {
66 TF_LITE_MAYBE_KERNEL_LOG(
67 tfLiteContext,
68 "TfLiteArmnnDelegate: Number to unpack must greater than zero in operator #%d node #%d: ",
69 operatorCode, nodeIndex);
70 return kTfLiteError;
71 }
72
73 // Check outputs
74 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, unpackNum, nodeIndex));
75
76
77 auto inputDimSize = inputTensorInfo.GetNumDimensions();
78 std::vector<unsigned int> unpackDimSizes(inputDimSize);
79
80 // Add current input shape to unpackDimSizes
81 for (unsigned int i = 0; i < inputDimSize; ++i)
82 {
83 unpackDimSizes[i] = inputTensorInfo.GetShape()[i];
84 }
85
86 if (unpackDimSizes[unpackAxis] != unpackNum)
87 {
88 TF_LITE_MAYBE_KERNEL_LOG(
89 tfLiteContext,
90 "TfLiteArmnnDelegate: Number to unpack must be the same as length "
91 "of the dimension to unpack along in operator #%d node #%d: ",
92 operatorCode, nodeIndex);
93 return kTfLiteError;
94 }
95
96 unpackDimSizes[unpackAxis] /= unpackNum;
97
98 armnn::SplitterDescriptor splitDesc(unpackNum, static_cast<unsigned int>(unpackDimSizes.size()));
Mike Kelly363b5722023-10-11 14:25:50 +010099 splitDesc.SetAxis(unpackAxis);
100
Kevin May8ab2d7a2021-05-07 09:32:51 +0100101 for (unsigned int j = 0; j < unpackNum; ++j)
102 {
103 // Set the size of the views.
104 for (unsigned int dimIdx = 0; dimIdx < unpackDimSizes.size(); ++dimIdx)
105 {
106 splitDesc.SetViewSize(j, dimIdx, unpackDimSizes[dimIdx]);
107 }
108 splitDesc.SetViewOriginCoord(j, unpackAxis, unpackDimSizes[unpackAxis] * j);
109 }
110
111 std::vector<armnn::TensorInfo> outputs;
112 for (unsigned int i = 0; i < unpackNum; ++i)
113 {
114 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
115 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
116 {
117 return kTfLiteError;
118 }
Sadik Armagan90a119b2022-08-05 16:12:49 +0100119 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Kevin May8ab2d7a2021-05-07 09:32:51 +0100120 }
121 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
122
Kevin May4cad8602021-05-18 09:57:43 +0100123 // Determine the shape of the Splitter layer outputs for validation
124 armnn::TensorShape splitOutShape = armnn::TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
125 unpackDimSizes.data());
126
127 std::vector<armnn::TensorInfo> splitterOutputs;
128 for (unsigned int outputIndex = 0; outputIndex < outputTensorInfos.size(); ++outputIndex)
129 {
130 splitterOutputs.push_back(armnn::TensorInfo(splitOutShape,
131 outputTensorInfos[outputIndex].get().GetDataType(),
132 outputTensorInfos[outputIndex].get().GetQuantizationScale(),
133 outputTensorInfos[outputIndex].get().GetQuantizationOffset()));
134 }
135 std::vector<std::reference_wrapper<armnn::TensorInfo>> splitterOutputTensorInfos(splitterOutputs.begin(),
136 splitterOutputs.end());
137
Cathal Corbett53837672022-09-01 11:34:37 +0100138 armnn::BackendId setBackendSplit;
Kevin May8ab2d7a2021-05-07 09:32:51 +0100139 if (!delegateData.m_Network)
140 {
Kevin May4cad8602021-05-18 09:57:43 +0100141 // Check if splitter is supported
Kevin May8ab2d7a2021-05-07 09:32:51 +0100142 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000143 FORWARD_LAYER_SUPPORT_FUNC("UNPACK",
Kevin May8ab2d7a2021-05-07 09:32:51 +0100144 tfLiteContext,
145 IsSplitterSupported,
146 delegateData.m_Backends,
147 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100148 setBackendSplit,
Kevin May8ab2d7a2021-05-07 09:32:51 +0100149 inputTensorInfo,
Kevin May4cad8602021-05-18 09:57:43 +0100150 splitterOutputTensorInfos,
Kevin May8ab2d7a2021-05-07 09:32:51 +0100151 splitDesc);
152 return isSupported ? kTfLiteOk : kTfLiteError;
153 }
154
Kevin May4cad8602021-05-18 09:57:43 +0100155 // Create Reshape descriptor from the first outputTensorInfo to validate a single Reshape layer
156 // Use this descriptor later when creating every ReshapeLayer as all Reshape Layers should be the same
157 armnn::ReshapeDescriptor reshapeDescriptor;
158 reshapeDescriptor.m_TargetShape = outputTensorInfos[0].get().GetShape();
159
Cathal Corbett53837672022-09-01 11:34:37 +0100160 armnn::BackendId setBackendReshape;
Kevin May4cad8602021-05-18 09:57:43 +0100161 if (!delegateData.m_Network)
162 {
163 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000164 FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
Kevin May4cad8602021-05-18 09:57:43 +0100165 tfLiteContext,
166 IsReshapeSupported,
167 delegateData.m_Backends,
168 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100169 setBackendReshape,
Kevin May4cad8602021-05-18 09:57:43 +0100170 splitterOutputTensorInfos[0],
171 outputTensorInfos[0],
172 reshapeDescriptor);
173 return isSupported ? kTfLiteOk : kTfLiteError;
174 };
175
Mike Kelly07169c82023-08-02 13:23:09 +0100176 auto layerName = GetLayerName(armnn::LayerType::Splitter, nodeIndex, "Unpack");
177 armnn::IConnectableLayer* splitterLayer = delegateData.m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100178 splitterLayer->SetBackendId(setBackendSplit);
Kevin May8ab2d7a2021-05-07 09:32:51 +0100179 ARMNN_ASSERT(splitterLayer != nullptr);
180
181 for (unsigned int k = 0; k < splitterLayer->GetNumOutputSlots(); ++k)
182 {
183 splitterLayer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
184 }
185
186 // Connect the input slots
187 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(splitterLayer->GetInputSlot(0));
188
Kevin May8ab2d7a2021-05-07 09:32:51 +0100189 // Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter.
190 for (unsigned int outputIndex = 0; outputIndex < splitterLayer->GetNumOutputSlots(); ++outputIndex)
191 {
Mike Kelly07169c82023-08-02 13:23:09 +0100192 auto reshapeName = GetLayerName(armnn::LayerType::Reshape, nodeIndex, "Unpack");
Kevin May8ab2d7a2021-05-07 09:32:51 +0100193 armnn::IConnectableLayer* reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor,
Mike Kelly07169c82023-08-02 13:23:09 +0100194 reshapeName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100195 reshapeLayer->SetBackendId(setBackendReshape);
Kevin May8ab2d7a2021-05-07 09:32:51 +0100196 ARMNN_ASSERT(reshapeLayer != nullptr);
197
Kevin May4cad8602021-05-18 09:57:43 +0100198 splitterLayer->GetOutputSlot(outputIndex).SetTensorInfo(splitterOutputTensorInfos[outputIndex]);
Kevin May8ab2d7a2021-05-07 09:32:51 +0100199 splitterLayer->GetOutputSlot(outputIndex).Connect(reshapeLayer->GetInputSlot(0));
200
Kevin May4cad8602021-05-18 09:57:43 +0100201 armnn::TensorInfo outputTensorInfo = outputTensorInfos[outputIndex];
Kevin May8ab2d7a2021-05-07 09:32:51 +0100202 reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
203
204 armnn::IOutputSlot& slot = reshapeLayer->GetOutputSlot(0);
205
206 delegateData.m_OutputSlotForNode[
207 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &slot;
208
209 }
210
211 return kTfLiteOk;
212}
213
214} // namespace armnnDelegate