blob: 2627c42f1fe868532d51999c73a980a4e65dadb2 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitTransposeOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfliteTransposeOperatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 // Gather input indices and use to get input tensor.
23 const int* inputTensors;
24 int numInputs;
25 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
26 {
27 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
28 tfLiteContext,
29 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
30 nodeIndex);
31 return kTfLiteError;
32 }
33 const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
34 if (!IsValid(tfLiteContext, tfLiteInputTensor0, tfliteTransposeOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38 const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
39 if (!IsValid(tfLiteContext, tfLiteInputTensor1, tfliteTransposeOperatorCode, nodeIndex))
40 {
41 return kTfLiteError;
42 }
43
44 // Gather output indices and use to get output tensors.
45 const int* outputTensors;
46 int numOutputs;
47 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
48 {
49 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
50 tfLiteContext,
51 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
52 nodeIndex);
53 return kTfLiteError;
54 }
55
56 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
57 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfliteTransposeOperatorCode, nodeIndex))
58 {
59 return kTfLiteError;
60 }
61
62 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
63 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
64
65 auto* permTensorDataPtr = static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteInputTensor1));
66 unsigned int numEl = TfLiteOpaqueTensorDim(tfLiteInputTensor1, 0);
67
68 ARMNN_ASSERT( numEl <= static_cast<int>(armnn::MaxNumOfTensorDimensions) );
69 // Ensure only single dimension to the permutation tensor
70 ARMNN_ASSERT( TfLiteOpaqueTensorNumDims(tfLiteInputTensor1) == 1 );
71
72 armnn::TransposeDescriptor descriptor(armnn::PermutationVector(
73 reinterpret_cast<const armnn::PermutationVector::ValueType *> (permTensorDataPtr),
74 static_cast<armnn::PermutationVector::SizeType>(numEl)));
75
76 bool isSupported = false;
77 armnn::BackendId setBackend;
78 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
79 {
80 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("TRANSPOSE",
81 tfLiteContext,
82 IsTransposeSupported,
83 delegateData.m_Backends,
84 isSupported,
85 setBackend,
86 inputTensorInfo0,
87 outputTensorInfo,
88 descriptor);
89 };
90
91 if (!delegateData.m_Network)
92 {
93 validateFunc(outputTensorInfo, isSupported);
94 return isSupported ? kTfLiteOk : kTfLiteError;
95 }
96
97 armnn::IConnectableLayer* transposeLayer = delegateData.m_Network->AddTransposeLayer(descriptor);
98 transposeLayer->SetBackendId(setBackend);
99 ARMNN_ASSERT(transposeLayer != nullptr);
100 // Permutation vector given to descriptor object
101 ARMNN_ASSERT(transposeLayer->GetNumInputSlots() == 1);
102
103 armnn::IOutputSlot& outputSlot = transposeLayer->GetOutputSlot(0);
104 outputSlot.SetTensorInfo(outputTensorInfo);
105
106 // try to connect the Constant Inputs if there are any
107 if(ProcessInputs(transposeLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
108 {
109 return kTfLiteError;
110 }
111
112 return Connect(transposeLayer, tfLiteContext, tfLiteNode, delegateData);
113}
114} // namespace armnnOpaqueDelegate