blob: 4b2bdf376a0e733c3b3fa48b989328e84bb81bf6 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Teresa Charlin42362962023-04-28 14:23:33 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitTransposeOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfliteTransposeOperatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 // Gather input indices and use to get input tensor.
23 const int* inputTensors;
24 int numInputs;
25 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
26 {
27 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
28 tfLiteContext,
29 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
30 nodeIndex);
31 return kTfLiteError;
32 }
33 const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
34 if (!IsValid(tfLiteContext, tfLiteInputTensor0, tfliteTransposeOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38 const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
39 if (!IsValid(tfLiteContext, tfLiteInputTensor1, tfliteTransposeOperatorCode, nodeIndex))
40 {
41 return kTfLiteError;
42 }
43
44 // Gather output indices and use to get output tensors.
45 const int* outputTensors;
46 int numOutputs;
47 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
48 {
49 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
50 tfLiteContext,
51 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
52 nodeIndex);
53 return kTfLiteError;
54 }
55
56 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
57 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfliteTransposeOperatorCode, nodeIndex))
58 {
59 return kTfLiteError;
60 }
61
62 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
63 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
64
65 auto* permTensorDataPtr = static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteInputTensor1));
66 unsigned int numEl = TfLiteOpaqueTensorDim(tfLiteInputTensor1, 0);
67
Ryan OSheac229b3f2023-06-27 22:34:54 +010068 if ( numEl > static_cast<int>(armnn::MaxNumOfTensorDimensions) )
69 {
70 return kTfLiteError;
71 }
72
Teresa Charlin42362962023-04-28 14:23:33 +010073 // Ensure only single dimension to the permutation tensor
Ryan OSheac229b3f2023-06-27 22:34:54 +010074 if ( TfLiteOpaqueTensorNumDims(tfLiteInputTensor1) != 1 )
75 {
76 return kTfLiteError;
77 }
Teresa Charlin42362962023-04-28 14:23:33 +010078
79 armnn::TransposeDescriptor descriptor(armnn::PermutationVector(
80 reinterpret_cast<const armnn::PermutationVector::ValueType *> (permTensorDataPtr),
81 static_cast<armnn::PermutationVector::SizeType>(numEl)));
82
83 bool isSupported = false;
84 armnn::BackendId setBackend;
85 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
86 {
87 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("TRANSPOSE",
88 tfLiteContext,
89 IsTransposeSupported,
90 delegateData.m_Backends,
91 isSupported,
92 setBackend,
93 inputTensorInfo0,
94 outputTensorInfo,
95 descriptor);
96 };
97
98 if (!delegateData.m_Network)
99 {
100 validateFunc(outputTensorInfo, isSupported);
101 return isSupported ? kTfLiteOk : kTfLiteError;
102 }
103
Mike Kellya2806502023-08-03 10:42:11 +0100104 auto layerName = GetName(armnn::LayerType::Transpose, nodeIndex);
105 armnn::IConnectableLayer* transposeLayer = delegateData.m_Network->AddTransposeLayer(descriptor, layerName.c_str());
Teresa Charlin42362962023-04-28 14:23:33 +0100106 transposeLayer->SetBackendId(setBackend);
107 ARMNN_ASSERT(transposeLayer != nullptr);
108 // Permutation vector given to descriptor object
Ryan OSheac229b3f2023-06-27 22:34:54 +0100109 if (transposeLayer->GetNumInputSlots() != 1)
110 {
111 return kTfLiteError;
112 }
Teresa Charlin42362962023-04-28 14:23:33 +0100113
114 armnn::IOutputSlot& outputSlot = transposeLayer->GetOutputSlot(0);
115 outputSlot.SetTensorInfo(outputTensorInfo);
116
117 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100118 if (ProcessInputs(transposeLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Teresa Charlin42362962023-04-28 14:23:33 +0100119 {
120 return kTfLiteError;
121 }
122
123 return Connect(transposeLayer, tfLiteContext, tfLiteNode, delegateData);
124}
125} // namespace armnnOpaqueDelegate